/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j.studio;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletConfig;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.CompileConfig;
import org.bsc.langgraph4j.CompiledGraph;
import org.bsc.langgraph4j.GraphInput;
import org.bsc.langgraph4j.GraphRepresentation;
import org.bsc.langgraph4j.GraphStateException;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.RunnableConfig;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.serializer.plain_text.PlainTextStateSerializer;
import org.bsc.langgraph4j.serializer.plain_text.jackson.JacksonStateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.StateSnapshot;
import org.bsc.langgraph4j.studio.NodeOutputSerializer;
import org.bsc.langgraph4j.utils.CollectionsUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public interface LangGraphStreamingServer {
    /** Logger for {@link LangGraphStreamingServer}, named after this interface. */
    public static final Logger log = LoggerFactory.getLogger(LangGraphStreamingServer.class);

    /**
     * Starts the streaming server.
     *
     * @return a future supplied by the implementation
     *         (NOTE(review): whether it completes when the server is up or when it shuts down is
     *         not visible here — confirm against the implementations)
     * @throws Exception if the server cannot be started
     */
    public CompletableFuture<Void> start() throws Exception;

    public static class GraphInitServlet
    extends HttpServlet {
        Logger log = log;
        final StateGraph<? extends AgentState> stateGraph;
        final ObjectMapper objectMapper = new ObjectMapper();
        InitData initData;

        public GraphInitServlet(StateGraph<? extends AgentState> stateGraph, String title, List<ArgumentMetadata> args) {
            Objects.requireNonNull(stateGraph, "stateGraph cannot be null");
            this.stateGraph = stateGraph;
            this.initData = new InitData(title, null, args);
        }

        public void init(ServletConfig config) throws ServletException {
            super.init(config);
            SimpleModule module = new SimpleModule();
            module.addSerializer(InitData.class, (JsonSerializer)new InitDataSerializer(InitData.class));
            this.objectMapper.registerModule((Module)module);
            try {
                CompiledGraph compiledGraph = this.stateGraph.compile();
                GraphRepresentation graph = compiledGraph.getGraph(GraphRepresentation.Type.MERMAID, null, false);
                this.initData = new InitData(this.initData.title(), graph.content(), this.initData.args(), this.initData.threads());
            }
            catch (GraphStateException ex) {
                throw new ServletException((Throwable)ex);
            }
        }

        protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
            response.setContentType("application/json");
            response.setCharacterEncoding("UTF-8");
            String resultJson = this.objectMapper.writeValueAsString((Object)this.initData);
            this.log.trace("{}", (Object)resultJson);
            PrintWriter writer = response.getWriter();
            writer.println(resultJson);
            writer.close();
        }
    }

    public static class InitDataSerializer
    extends StdSerializer<InitData> {
        Logger log = log;

        protected InitDataSerializer(Class<InitData> t) {
            super(t);
        }

        public void serialize(InitData initData, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
            this.log.trace("InitDataSerializer start!");
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("graph", initData.graph());
            jsonGenerator.writeStringField("title", initData.title());
            jsonGenerator.writeObjectField("args", initData.args());
            jsonGenerator.writeArrayFieldStart("threads");
            for (ThreadEntry thread : initData.threads()) {
                jsonGenerator.writeStartArray();
                jsonGenerator.writeString(thread.id());
                jsonGenerator.writeStartArray(thread.entries());
                jsonGenerator.writeEndArray();
                jsonGenerator.writeEndArray();
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeEndObject();
        }
    }

    /**
     * Initial payload for the studio client.
     *
     * @param title   title shown by the studio
     * @param graph   graph diagram source ({@code GraphInitServlet} fills it with the Mermaid
     *                representation; may be null before that)
     * @param args    metadata for the input arguments
     * @param threads known threads and their recorded outputs
     */
    public record InitData(String title, String graph, List<ArgumentMetadata> args, List<ThreadEntry> threads) {

        /** Creates init data whose thread list holds a single empty thread named {@code "default"}. */
        public InitData(String pageTitle, String graphSource, List<ArgumentMetadata> arguments) {
            this(pageTitle, graphSource, arguments, defaultThreads());
        }

        /** Builds the one-element thread list used by the three-argument constructor. */
        private static List<ThreadEntry> defaultThreads() {
            ThreadEntry onlyThread = new ThreadEntry("default", List.of());
            return List.of(onlyThread);
        }
    }

    /**
     * A thread identifier paired with the node outputs associated with it.
     *
     * @param id      thread identifier
     * @param entries node outputs for the thread
     */
    public record ThreadEntry(String id, List<? extends NodeOutput<? extends AgentState>> entries) { }

    /**
     * Describes one input argument exposed to the studio client.
     *
     * @param name      argument name, matched against request data keys; must not be null
     * @param type      kind of argument; must not be null
     * @param required  flag serialized to the client marking the argument as required
     * @param converter optional transformation applied to the submitted value (may be null);
     *                  excluded from JSON via {@code @JsonIgnore}
     */
    public record ArgumentMetadata(String name, ArgumentType type, boolean required, @JsonIgnore Function<Object, Object> converter) {
        public ArgumentMetadata {
            if (name == null) {
                throw new NullPointerException("name cannot be null");
            }
            if (type == null) {
                throw new NullPointerException("type cannot be null");
            }
        }

        /** Creates metadata without a value converter. */
        public ArgumentMetadata(String argName, ArgumentType argType, boolean isRequired) {
            this(argName, argType, isRequired, null);
        }

        /** Supported argument kinds. */
        public enum ArgumentType {
            STRING,
            IMAGE
        }
    }

    /**
     * Servlet that runs the graph for a client-supplied thread and streams every snapshot back as
     * plain text, one {@code [ "<threadId>", <nodeOutput> ]} chunk per output.
     * <p>
     * Compiled graphs are cached per (HTTP session id, thread id) so that a later request with
     * {@code resume=true} can update the state of, and resume, the same compiled graph.
     */
    public static class GraphStreamServlet
    extends HttpServlet {
        // Qualified reference is required: a bare `log` here names this very field (self-reference).
        final Logger log = LangGraphStreamingServer.log;
        final StateGraph<? extends AgentState> stateGraph;
        final ObjectMapper objectMapper;
        // doPost is invoked concurrently by the container, so the cache must be thread-safe.
        // NOTE(review): entries are never evicted; the cache grows with every session/thread pair.
        final Map<PersistentConfig, CompiledGraph<? extends AgentState>> graphCache = new ConcurrentHashMap<PersistentConfig, CompiledGraph<? extends AgentState>>();
        final CompileConfig compileConfig;
        final List<ArgumentMetadata> args;

        /**
         * @param stateGraph    graph to execute; must not be null
         * @param compileConfig configuration used when compiling a graph for a new thread; must not be null
         * @param args          argument metadata whose converters are applied to request data; must not be null
         * @throws NullPointerException if any argument is null
         */
        public GraphStreamServlet(StateGraph<? extends AgentState> stateGraph, CompileConfig compileConfig, List<ArgumentMetadata> args) {
            this.stateGraph = Objects.requireNonNull(stateGraph, "stateGraph cannot be null");
            this.compileConfig = Objects.requireNonNull(compileConfig, "compileConfig cannot be null");
            this.args = Objects.requireNonNull(args, "args cannot be null");
            StateSerializer stateSerializer = stateGraph.getStateSerializer();
            if (stateSerializer instanceof JacksonStateSerializer) {
                // Copy (rather than share) the state serializer's mapper so the module registered
                // below does not modify it.
                JacksonStateSerializer jsonSerializer = (JacksonStateSerializer) stateSerializer;
                this.objectMapper = jsonSerializer.objectMapper().copy();
            } else {
                this.objectMapper = new ObjectMapper();
                this.objectMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
            }
            SimpleModule module = new SimpleModule();
            module.addSerializer(NodeOutput.class, (JsonSerializer) new NodeOutputSerializer());
            this.objectMapper.registerModule((Module) module);
        }

        @Override
        public void init(ServletConfig config) throws ServletException {
            super.init(config);
        }

        /** Returns the compile configuration for a thread; currently the shared one for all threads. */
        private CompileConfig compileConfig(PersistentConfig config) {
            return this.compileConfig;
        }

        /** Builds the run configuration for a fresh (non-resumed) execution of a thread. */
        RunnableConfig runnableConfig(PersistentConfig config) {
            return RunnableConfig.builder().threadId(config.threadId()).build();
        }

        /**
         * Writes one {@code [ "<threadId>", <output> ]} chunk.
         * <p>
         * Both parts are serialized before anything is written, so a serialization failure leaves
         * no half-written chunk. The thread id comes from the request and is JSON-encoded rather
         * than interpolated, so quotes or backslashes in it cannot break the output.
         */
        private void serializeOutput(PrintWriter writer, String threadId, NodeOutput<? extends AgentState> output) {
            try {
                String threadIdJson = this.objectMapper.writeValueAsString(threadId);
                String outputAsString = this.objectMapper.writeValueAsString(output);
                writer.print("[ " + threadIdJson + ",");
                writer.println();
                writer.println(outputAsString);
                writer.println("]");
            }
            catch (IOException e) {
                this.log.warn("error serializing state", e);
            }
        }

        /**
         * Reads the request body into a map, using the graph's plain-text serializer when it has
         * one and JSON otherwise. The plain-text body is decoded as UTF-8 (not the platform charset).
         */
        @SuppressWarnings("unchecked")
        private Map<String, Object> readRequestData(HttpServletRequest request) throws Exception {
            StateSerializer stateSerializer = this.stateGraph.getStateSerializer();
            if (stateSerializer instanceof PlainTextStateSerializer) {
                PlainTextStateSerializer textSerializer = (PlainTextStateSerializer) stateSerializer;
                Reader reader = new InputStreamReader(request.getInputStream(), StandardCharsets.UTF_8);
                return textSerializer.read(reader).data();
            }
            return this.objectMapper.readValue(request.getInputStream(), new TypeReference<Map<String, Object>>(){});
        }

        /**
         * Returns a copy of {@code input} in which every value whose key matches an argument with a
         * converter is replaced by the converter's result. If the converter returns null, the
         * original value is kept. Null values are allowed (unlike {@code Collectors.toMap}, which
         * throws on them).
         */
        private Map<String, Object> applyArgumentConverters(Map<String, Object> input) {
            Map<String, Object> converted = new HashMap<String, Object>();
            for (Map.Entry<String, Object> entry : input.entrySet()) {
                String name = entry.getKey();
                Object value = entry.getValue();
                Object newValue = this.args.stream()
                        .filter(arg -> arg.name().equals(name) && arg.converter() != null)
                        .findAny()
                        .map(arg -> arg.converter().apply(value))
                        .orElse(value);
                converted.put(name, newValue);
            }
            return converted;
        }

        /**
         * Starts (or, with {@code resume=true}, resumes) the graph for the {@code thread} request
         * parameter and streams its snapshots asynchronously.
         * <p>
         * Request parameters: {@code thread} (required), {@code resume} (optional boolean), and,
         * when resuming, {@code checkpoint} (required) and {@code node}.
         */
        @Override
        protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
            // NOTE(review): Accept is a request header; setting it on a response has no standard meaning.
            response.setHeader("Accept", "application/json");
            response.setContentType("text/plain");
            response.setCharacterEncoding("UTF-8");
            HttpSession session = request.getSession(true);
            Objects.requireNonNull(session, "session cannot be null");
            String threadId = Optional.ofNullable(request.getParameter("thread")).orElseThrow(() -> new IllegalStateException("Missing thread id!"));
            boolean resume = Optional.ofNullable(request.getParameter("resume")).map(Boolean::parseBoolean).orElse(false);
            PrintWriter writer = response.getWriter();
            AsyncContext asyncContext = request.startAsync();
            try {
                PersistentConfig persistentConfig = new PersistentConfig(session.getId(), threadId);
                CompiledGraph compiledGraph = this.graphCache.get(persistentConfig);
                Map<String, Object> dataMap = this.applyArgumentConverters(this.readRequestData(request));
                AsyncGenerator generator;
                if (resume) {
                    this.log.trace("RESUME REQUEST PREPARE");
                    if (compiledGraph == null) {
                        throw new IllegalStateException("Missing CompiledGraph in session!");
                    }
                    String checkpointId = Optional.ofNullable(request.getParameter("checkpoint")).orElseThrow(() -> new IllegalStateException("Missing checkpoint id!"));
                    String node = request.getParameter("node");
                    RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).checkPointId(checkpointId).nextNode(node).build();
                    StateSnapshot stateSnapshot = compiledGraph.getState(runnableConfig);
                    runnableConfig = stateSnapshot.config();
                    this.log.trace("RESUME UPDATE STATE FORM {} USING CONFIG {}\n{}", node, runnableConfig, dataMap);
                    runnableConfig = compiledGraph.updateState(runnableConfig, dataMap, node);
                    this.log.trace("RESUME REQUEST STREAM {}", runnableConfig);
                    generator = compiledGraph.streamSnapshots(GraphInput.resume(), runnableConfig);
                } else {
                    this.log.trace("dataMap: {}", dataMap);
                    if (compiledGraph == null) {
                        // putIfAbsent: if a concurrent request cached a graph first, use that one.
                        CompiledGraph<? extends AgentState> compiled = this.stateGraph.compile(this.compileConfig(persistentConfig));
                        CompiledGraph<? extends AgentState> cached = this.graphCache.putIfAbsent(persistentConfig, compiled);
                        compiledGraph = cached != null ? cached : compiled;
                    }
                    generator = compiledGraph.streamSnapshots(dataMap, this.runnableConfig(persistentConfig));
                }
                CompletableFuture<?> streaming = (CompletableFuture<?>) generator.forEachAsync(output -> {
                    try {
                        this.serializeOutput(writer, threadId, (NodeOutput<? extends AgentState>) output);
                        writer.println();
                        writer.flush();
                        // Deliberate pause between outputs (presumably to pace the studio UI — confirm).
                        TimeUnit.SECONDS.sleep(1L);
                    }
                    catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        throw new CompletionException(e);
                    }
                });
                streaming.whenComplete((ignored, error) -> {
                    if (error != null) {
                        this.log.error("Error streaming", error);
                    }
                    try {
                        writer.close();
                    }
                    finally {
                        asyncContext.complete();
                    }
                });
            }
            catch (Throwable e) {
                this.log.error("Error streaming", e);
                throw new ServletException(e);
            }
        }
    }

    public record PersistentConfig(String sessionId, String threadId) {
        public PersistentConfig {
            Objects.requireNonNull(sessionId);
        }
    }
}

