/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
import io.modelcontextprotocol.server.McpTransportContext;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.server.transport.IMcpHttpServerTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.KeepAliveScheduler;
import io.modelcontextprotocol.util.Utils;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
import org.noear.solon.SolonApp;
import org.noear.solon.core.handle.Context;
import org.noear.solon.core.handle.Entity;
import org.noear.solon.core.util.PathUtil;
import org.noear.solon.web.sse.SseEmitter;
import org.noear.solon.web.sse.SseEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * MCP server transport provider for the HTTP+SSE transport (protocol version {@code 2024-11-05})
 * built on Solon.
 *
 * <p>A client opens an SSE stream with {@code GET} on the SSE endpoint. Once the stream is
 * initialised, the server creates a {@link McpServerSession} and sends an
 * {@value #ENDPOINT_EVENT_TYPE} event. That event carries the message URL, including a
 * {@code sessionId} query parameter. The client then {@code POST}s JSON-RPC messages to that
 * URL. Messages emitted by the session are pushed back over the SSE stream as
 * {@value #MESSAGE_EVENT_TYPE} events.
 */
public class WebRxSseServerTransportProvider implements McpServerTransportProvider, IMcpHttpServerTransport {
    private static final Logger logger = LoggerFactory.getLogger(WebRxSseServerTransportProvider.class);
    public static final String MESSAGE_EVENT_TYPE = "message";
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    /** Query parameter through which POSTed messages identify their session. */
    private static final String SESSION_ID_PARAM = "sessionId";
    /** Reactor context key under which the extracted transport context is published. */
    private static final String TRANSPORT_CONTEXT_KEY = "MCP_TRANSPORT_CONTEXT";

    private final ObjectMapper objectMapper;
    /** Path the POST handler is registered on. */
    private final String messageEndpoint;
    /** {@code baseUrl} joined with {@code messageEndpoint}; advertised to clients in the endpoint event. */
    private final String messageEndpointFull;
    private final String sseEndpoint;
    private McpServerSession.Factory sessionFactory;
    /** Active sessions, keyed by session id. */
    private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();
    /** The SSE (GET) request that opened each session, keyed by session id. */
    private final ConcurrentHashMap<String, Context> sessionRequests = new ConcurrentHashMap<>();
    private McpTransportContextExtractor<Context> contextExtractor;
    private volatile boolean isClosing = false;
    private KeepAliveScheduler keepAliveScheduler;

    /**
     * Creates the provider.
     *
     * <p>If {@code keepAliveInterval} is non-null, this also starts a {@link KeepAliveScheduler}.
     * The scheduler runs over all active sessions, and over none once shutdown has begun.
     *
     * @param objectMapper      JSON mapper used for (de)serializing JSON-RPC messages; not null
     * @param baseUrl           prefix joined with {@code messageEndpoint} to form the advertised URL; not null
     * @param messageEndpoint   path of the POST message handler; not null
     * @param sseEndpoint       path of the GET SSE handler; not null
     * @param contextExtractor  builds the transport context from the session's SSE request
     * @param keepAliveInterval keep-alive period, or {@code null} to disable keep-alive
     * @deprecated use {@link #builder()} instead.
     */
    @Deprecated
    public WebRxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint, McpTransportContextExtractor<Context> contextExtractor, Duration keepAliveInterval) {
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        Assert.notNull(baseUrl, "Message base URL must not be null");
        Assert.notNull(messageEndpoint, "Message endpoint must not be null");
        Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
        this.objectMapper = objectMapper;
        this.messageEndpoint = messageEndpoint;
        this.messageEndpointFull = PathUtil.joinUri(baseUrl, messageEndpoint);
        this.sseEndpoint = sseEndpoint;
        this.contextExtractor = contextExtractor;
        if (keepAliveInterval != null) {
            this.keepAliveScheduler = KeepAliveScheduler
                    .builder(() -> this.isClosing ? Flux.empty() : Flux.fromIterable(this.sessions.values()))
                    .initialDelay(keepAliveInterval)
                    .interval(keepAliveInterval)
                    .build();
            this.keepAliveScheduler.start();
        }
    }

    /**
     * Registers the GET (SSE) and POST (message) handlers on {@code app}.
     * Does nothing when {@code app} is null.
     */
    @Override
    public void toHttpHandler(SolonApp app) {
        if (app != null) {
            app.get(this.sseEndpoint, this::handleSseConnection);
            app.post(this.messageEndpoint, this::handleMessage);
        }
    }

    /** Returns the SSE endpoint path clients connect to. */
    @Override
    public String getMcpEndpoint() {
        return this.sseEndpoint;
    }

    /** This transport only speaks the {@code 2024-11-05} HTTP+SSE protocol. */
    @Override
    public List<String> protocolVersions() {
        return Utils.asList("2024-11-05");
    }

    @Override
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Sends a notification to every active session.
     *
     * <p>A failure for one session is logged and does not prevent delivery to the others.
     */
    @Override
    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        return Flux.fromIterable(this.sessions.values())
                .flatMap(session -> session.sendNotification(method, params)
                        .doOnError(e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
                        .onErrorComplete())
                .then();
    }

    /**
     * Marks the provider as closing and gracefully closes every active session.
     *
     * <p>A session that fails to close is logged and skipped. It does not cancel the closing of
     * the others. Once all sessions are done, the session map is cleared and the keep-alive
     * scheduler, if any, is shut down.
     */
    @Override
    public Mono<Void> closeGracefully() {
        return Flux.fromIterable(this.sessions.values())
                .doFirst(() -> {
                    this.isClosing = true;
                    logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
                })
                // Isolate per-session failures so one bad session cannot skip the cleanup below.
                .flatMap(session -> session.closeGracefully()
                        .doOnError(e -> logger.warn("Failed to close session {}: {}", session.getId(), e.getMessage()))
                        .onErrorComplete())
                .then()
                .doOnSuccess(v -> {
                    logger.debug("Graceful shutdown completed");
                    this.sessions.clear();
                    if (this.keepAliveScheduler != null) {
                        this.keepAliveScheduler.shutdown();
                    }
                });
    }

    /** GET handler: opens an SSE stream, or returns an error entity. */
    private void handleSseConnection(Context request) throws Throwable {
        Object returnValue = this.handleSseConnectionDo(request);
        if (returnValue instanceof Entity) {
            this.renderBody((Entity) returnValue);
        }
        request.returnValue(returnValue);
    }

    /**
     * Creates the SSE emitter for a new session.
     *
     * @return the {@link SseEmitter} on success; otherwise an {@link Entity} with status 503
     *         (shutting down) or 500 (emitter setup failed)
     */
    private Object handleSseConnectionDo(Context request) {
        if (this.isClosing) {
            return new Entity().status(503).body("Server is shutting down");
        }
        String sessionId = UUID.randomUUID().toString();
        logger.debug("Creating new SSE connection for session: {}", sessionId);
        try {
            SseEmitter sseBuilder = new SseEmitter(-1L);
            sseBuilder.onCompletion(() -> {
                logger.debug("SSE connection completed for session: {}", sessionId);
                this.removeSession(sessionId);
            });
            sseBuilder.onTimeout(() -> {
                logger.debug("SSE connection timed out for session: {}", sessionId);
                this.removeSession(sessionId);
            });
            // The session is only created once the stream is initialised, so the endpoint
            // event can be written straight away.
            sseBuilder.onInited(emitter -> this.registerSession(sessionId, sseBuilder, request));
            return sseBuilder;
        }
        catch (Exception e) {
            logger.error("Failed to create SSE connection for session {}: {}", sessionId, e.getMessage(), e);
            this.removeSession(sessionId);
            return new Entity().status(500);
        }
    }

    /**
     * Creates and stores the session for a freshly initialised SSE stream.
     *
     * <p>Then sends the {@value #ENDPOINT_EVENT_TYPE} event telling the client where to POST.
     * If sending fails, the emitter is failed with the cause.
     */
    private void registerSession(String sessionId, SseEmitter sseBuilder, Context request) {
        WebRxMcpSessionTransport sessionTransport = new WebRxMcpSessionTransport(sessionId, sseBuilder);
        McpServerSession session = this.sessionFactory.create(sessionTransport);
        this.sessions.put(sessionId, session);
        this.sessionRequests.put(sessionId, request);
        try {
            sseBuilder.send(new SseEvent().id(sessionId).name(ENDPOINT_EVENT_TYPE).data(this.messageEndpointFull + "?sessionId=" + sessionId));
        }
        catch (Exception e) {
            logger.error("Failed to send initial endpoint event: {}", e.getMessage());
            sseBuilder.error(e);
        }
    }

    /** Removes every piece of bookkeeping for {@code sessionId}; a no-op for unknown ids. */
    private void removeSession(String sessionId) {
        this.sessions.remove(sessionId);
        this.sessionRequests.remove(sessionId);
    }

    /** POST handler: dispatches one JSON-RPC message to its session. */
    private void handleMessage(Context request) throws Throwable {
        Entity entity = this.handleMessageDo(request);
        this.renderBody(entity);
        request.returnValue(entity);
    }

    /**
     * Rewrites an entity body into something Solon can write directly.
     *
     * <p>An {@link McpError} is replaced by its message, and a {@link McpSchema.JSONRPCResponse}
     * is serialized to JSON. Any other body, including {@code null}, is left untouched.
     */
    private void renderBody(Entity entity) throws IOException {
        Object body = entity.body();
        if (body instanceof McpError) {
            entity.body(((McpError) body).getMessage());
        } else if (body instanceof McpSchema.JSONRPCResponse) {
            entity.body(this.objectMapper.writeValueAsString(body));
        }
    }

    /**
     * Parses the POSTed JSON-RPC message and hands it to the session named by the
     * {@code sessionId} query parameter. Blocks until the session has handled it.
     *
     * @return an entity with one of these statuses:
     *         <ul>
     *           <li>200 when the message was handled;</li>
     *           <li>400 when {@code sessionId} is missing or the body is not a valid message;</li>
     *           <li>404 when the session is unknown;</li>
     *           <li>500 when handling fails;</li>
     *           <li>503 while shutting down.</li>
     *         </ul>
     */
    private Entity handleMessageDo(Context request) {
        if (this.isClosing) {
            return new Entity().status(503).body("Server is shutting down");
        }
        // param() may return null when the parameter is absent; guard before calling isEmpty().
        String sessionId = request.param(SESSION_ID_PARAM);
        if (sessionId == null || sessionId.isEmpty()) {
            return new Entity().status(400).body(new McpError("Session ID missing in message endpoint"));
        }
        McpServerSession session = this.sessions.get(sessionId);
        Context sessionRequest = this.sessionRequests.get(sessionId);
        // Both maps are updated separately, so a concurrently completing stream can leave
        // only one of them populated; treat that as an already-gone session.
        if (session == null || sessionRequest == null) {
            return new Entity().status(404).body(new McpError("Session not found: " + sessionId));
        }
        McpTransportContext transportContext = this.contextExtractor.extract(sessionRequest, new DefaultMcpTransportContext());
        McpSchema.JSONRPCMessage message;
        try {
            message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, request.body());
        }
        catch (IOException | IllegalArgumentException e) {
            logger.error("Failed to deserialize message: {}", e.getMessage());
            return new Entity().status(400).body(new McpError("Invalid message format"));
        }
        try {
            session.handle(message).contextWrite(ctx -> ctx.put(TRANSPORT_CONTEXT_KEY, transportContext)).block();
            return new Entity().status(200);
        }
        catch (Exception e) {
            logger.error("Error handling message: {}", e.getMessage(), e);
            return new Entity().status(500).body(new McpError(e.getMessage()));
        }
    }

    /** Returns a new {@link Builder}; {@link Builder#messageEndpoint(String)} is mandatory. */
    public static Builder builder() {
        return new Builder();
    }

    /** Fluent builder for {@link WebRxSseServerTransportProvider}. */
    public static class Builder {
        private ObjectMapper objectMapper = new ObjectMapper();
        private String baseUrl = "";
        private String messageEndpoint;
        private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
        /** Default extractor: stores the SSE request under the {@code Context} class name. */
        private McpTransportContextExtractor<Context> contextExtractor = (serverRequest, context) -> {
            context.put(Context.class.getName(), serverRequest);
            return context;
        };
        private Duration keepAliveInterval;

        /** Sets the JSON mapper; must not be null. */
        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        /** Sets the base URL prefixed to the advertised message endpoint; null is ignored. */
        public Builder baseUrl(String baseUrl) {
            if (baseUrl != null) {
                this.baseUrl = baseUrl;
            }
            return this;
        }

        /** Sets the POST message endpoint path; required and must have text. */
        public Builder messageEndpoint(String messageEndpoint) {
            Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
            this.messageEndpoint = messageEndpoint;
            return this;
        }

        /** Sets the GET SSE endpoint path; must have text. Defaults to {@value #DEFAULT_SSE_ENDPOINT}. */
        public Builder sseEndpoint(String sseEndpoint) {
            Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        /** Sets the transport-context extractor; must not be null. */
        public Builder contextExtractor(McpTransportContextExtractor<Context> contextExtractor) {
            Assert.notNull(contextExtractor, "contextExtractor must not be null");
            this.contextExtractor = contextExtractor;
            return this;
        }

        /** Sets the keep-alive interval; {@code null} (the default) disables keep-alive. */
        public Builder keepAliveInterval(Duration keepAliveInterval) {
            this.keepAliveInterval = keepAliveInterval;
            return this;
        }

        /**
         * Builds the provider.
         *
         * @throws IllegalStateException if no message endpoint was set
         */
        public WebRxSseServerTransportProvider build() {
            if (this.messageEndpoint == null) {
                throw new IllegalStateException("MessageEndpoint must be set");
            }
            return new WebRxSseServerTransportProvider(this.objectMapper, this.baseUrl, this.messageEndpoint, this.sseEndpoint, this.contextExtractor, this.keepAliveInterval);
        }
    }

    /**
     * Per-session transport that writes outgoing JSON-RPC messages to the session's SSE stream.
     *
     * <p>All writes and completion go through {@code sseBuilderLock}, so concurrent sends do not
     * interleave on the emitter.
     */
    private class WebRxMcpSessionTransport implements McpServerTransport {
        private final String sessionId;
        private final SseEmitter sseBuilder;
        private final ReentrantLock sseBuilderLock = new ReentrantLock();

        WebRxMcpSessionTransport(String sessionId, SseEmitter sseBuilder) {
            this.sessionId = sessionId;
            this.sseBuilder = sseBuilder;
            logger.debug("Session transport {} initialized with SSE builder", sessionId);
        }

        /**
         * Serializes {@code message} and sends it as a {@value #MESSAGE_EVENT_TYPE} event.
         *
         * <p>On failure the error is logged and the emitter is failed. The returned Mono still
         * completes normally.
         */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromRunnable(() -> {
                this.sseBuilderLock.lock();
                try {
                    String jsonText = WebRxSseServerTransportProvider.this.objectMapper.writeValueAsString(message);
                    this.sseBuilder.send(new SseEvent().id(this.sessionId).name(MESSAGE_EVENT_TYPE).data(jsonText));
                    logger.debug("Message sent to session {}: {}", this.sessionId, jsonText);
                }
                catch (Exception e) {
                    logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
                    this.sseBuilder.error(e);
                }
                finally {
                    this.sseBuilderLock.unlock();
                }
            });
        }

        /** Converts {@code data} to {@code T} with the provider's {@link ObjectMapper}. */
        @Override
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return WebRxSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
        }

        /** Lazily completes the SSE stream; see {@link #close()}. */
        @Override
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> {
                logger.debug("Closing session transport: {}", this.sessionId);
                this.close();
            });
        }

        /** Completes the SSE stream under the send lock; failures are logged, not thrown. */
        @Override
        public void close() {
            this.sseBuilderLock.lock();
            try {
                this.sseBuilder.complete();
                logger.debug("Successfully completed SSE builder for session {}", this.sessionId);
            }
            catch (Exception e) {
                logger.warn("Failed to complete SSE builder for session {}: {}", this.sessionId, e.getMessage());
            }
            finally {
                this.sseBuilderLock.unlock();
            }
        }
    }
}

