/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.transport;

import java.io.IOException;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.neuralsearch.stats.common.StatSnapshot;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;
import org.opensearch.neuralsearch.stats.events.TimestampedEventStatSnapshot;
import org.opensearch.neuralsearch.stats.info.InfoStatName;
import org.opensearch.neuralsearch.stats.info.InfoStatsManager;
import org.opensearch.neuralsearch.stats.metrics.MemoryStatSnapshot;
import org.opensearch.neuralsearch.stats.metrics.MetricStatName;
import org.opensearch.neuralsearch.stats.metrics.MetricStatsManager;
import org.opensearch.neuralsearch.transport.NeuralStatsNodeRequest;
import org.opensearch.neuralsearch.transport.NeuralStatsNodeResponse;
import org.opensearch.neuralsearch.transport.NeuralStatsRequest;
import org.opensearch.neuralsearch.transport.NeuralStatsResponse;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/**
 * Nodes-level transport action backing the neural stats API.
 *
 * <p>Each target node reports its event stats and metric stats from {@link #nodeOperation}. Those
 * per-node responses are then combined in {@link #newResponse} into a {@link NeuralStatsResponse}
 * containing:
 * <ul>
 *   <li>a per-node view keyed by node id,</li>
 *   <li>optionally, a cross-node aggregate, and</li>
 *   <li>optionally, info stats read from {@link InfoStatsManager}.</li>
 * </ul>
 */
public class NeuralStatsTransportAction
    extends TransportNodesAction<NeuralStatsRequest, NeuralStatsResponse, NeuralStatsNodeRequest, NeuralStatsNodeResponse> {
    private final EventStatsManager eventStatsManager;
    private final InfoStatsManager infoStatsManager;
    private final MetricStatsManager metricStatsManager;

    /**
     * Registers the action as {@code cluster:admin/neural_stats_action}. Node-level work is
     * dispatched to the {@code management} executor.
     *
     * @param threadPool         thread pool passed to the base action
     * @param clusterService     cluster service, used for the cluster name and local node
     * @param transportService   transport service passed to the base action
     * @param actionFilters      action filters passed to the base action
     * @param eventStatsManager  source of per-node event stat snapshots
     * @param infoStatsManager   source of info stats, included when requested
     * @param metricStatsManager source of per-node metric stat snapshots
     */
    @Inject
    public NeuralStatsTransportAction(
        ThreadPool threadPool,
        ClusterService clusterService,
        TransportService transportService,
        ActionFilters actionFilters,
        EventStatsManager eventStatsManager,
        InfoStatsManager infoStatsManager,
        MetricStatsManager metricStatsManager
    ) {
        super(
            "cluster:admin/neural_stats_action",
            threadPool,
            clusterService,
            transportService,
            actionFilters,
            NeuralStatsRequest::new,
            NeuralStatsNodeRequest::new,
            "management",
            NeuralStatsNodeResponse.class
        );
        this.eventStatsManager = eventStatsManager;
        this.infoStatsManager = infoStatsManager;
        this.metricStatsManager = metricStatsManager;
    }

    /**
     * Combines the per-node responses into the final response.
     *
     * @param request   the original request; its stats input selects which sections to build
     * @param responses successful per-node responses
     * @param failures  per-node failures, forwarded unchanged
     * @return the assembled response
     */
    @Override
    protected NeuralStatsResponse newResponse(
        NeuralStatsRequest request,
        List<NeuralStatsNodeResponse> responses,
        List<FailedNodeException> failures
    ) {
        var input = request.getNeuralStatsInput();

        // Per-node view: node id -> (stat full path -> snapshot).
        Map<String, Map<String, StatSnapshot<?>>> nodeIdToNodeStats = processNodeStatsIntoMap(responses);

        Map<String, StatSnapshot<?>> aggregatedNodeStats = Collections.emptyMap();
        if (input.isIncludeAllNodes()) {
            aggregatedNodeStats = aggregateNodesResponses(responses, input.getEventStatNames(), input.getMetricStatNames());
        }

        Map<String, StatSnapshot<?>> flatInfoStats = Collections.emptyMap();
        if (input.isIncludeInfo()) {
            Map<InfoStatName, StatSnapshot<?>> infoStats = infoStatsManager.getStats(input.getInfoStatNames());
            // Re-key by the stat's full path so it matches the other sections.
            flatInfoStats = infoStats.entrySet()
                .stream()
                .collect(Collectors.toMap(entry -> entry.getKey().getFullPath(), Map.Entry::getValue));
        }

        return new NeuralStatsResponse(
            clusterService.getClusterName(),
            responses,
            failures,
            flatInfoStats,
            aggregatedNodeStats,
            nodeIdToNodeStats,
            input.isFlatten(),
            input.isIncludeMetadata(),
            input.isIncludeIndividualNodes(),
            input.isIncludeAllNodes(),
            input.isIncludeInfo(),
            input.isIncludeMetrics()
        );
    }

    /** Wraps the top-level request into the request sent to each node. */
    @Override
    protected NeuralStatsNodeRequest newNodeRequest(NeuralStatsRequest request) {
        return new NeuralStatsNodeRequest(request);
    }

    /** Deserializes a single node's response from the stream. */
    @Override
    protected NeuralStatsNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new NeuralStatsNodeResponse(in);
    }

    /**
     * Collects the requested event and metric stat snapshots on this node.
     *
     * @param request node request carrying the requested event and metric stat names
     * @return this node's snapshots, tagged with the local node
     */
    @Override
    protected NeuralStatsNodeResponse nodeOperation(NeuralStatsNodeRequest request) {
        var input = request.getRequest().getNeuralStatsInput();
        EnumSet<EventStatName> eventStatsToRetrieve = input.getEventStatNames();
        Map<EventStatName, TimestampedEventStatSnapshot> eventStatDataMap = eventStatsManager.getTimestampedEventStatSnapshots(
            eventStatsToRetrieve
        );
        EnumSet<MetricStatName> metricStatsToRetrieve = input.getMetricStatNames();
        Map<MetricStatName, MemoryStatSnapshot> metricStatDataMap = metricStatsManager.getStats(metricStatsToRetrieve);
        return new NeuralStatsNodeResponse(clusterService.localNode(), eventStatDataMap, metricStatDataMap);
    }

    /**
     * Aggregates the requested event and metric stats across all node responses.
     *
     * @param responses             per-node responses; {@code null} or empty yields an empty map
     * @param eventStatsToRetrieve  event stats to aggregate
     * @param metricStatsToRetrieve metric stats to aggregate
     * @return mutable map from stat full path to aggregated snapshot; stats whose aggregation
     *         returns {@code null} are omitted
     */
    private Map<String, StatSnapshot<?>> aggregateNodesResponses(
        List<NeuralStatsNodeResponse> responses,
        EnumSet<EventStatName> eventStatsToRetrieve,
        EnumSet<MetricStatName> metricStatsToRetrieve
    ) {
        Map<String, StatSnapshot<?>> aggregatedMap = new HashMap<>();
        if (responses == null || responses.isEmpty()) {
            return aggregatedMap;
        }
        aggregateEventStats(responses, eventStatsToRetrieve, aggregatedMap);
        aggregateMetricStats(responses, metricStatsToRetrieve, aggregatedMap);
        return aggregatedMap;
    }

    /** Puts one cross-node aggregate per requested event stat into {@code aggregatedMap}, keyed by full path. */
    private static void aggregateEventStats(
        List<NeuralStatsNodeResponse> responses,
        EnumSet<EventStatName> eventStatsToRetrieve,
        Map<String, StatSnapshot<?>> aggregatedMap
    ) {
        for (EventStatName eventStatName : eventStatsToRetrieve) {
            // NOTE(review): a HashSet collapses snapshots that compare equal. Confirm that
            // TimestampedEventStatSnapshot uses identity equality. Otherwise identical values
            // from different nodes would be counted once.
            HashSet<TimestampedEventStatSnapshot> nodeSnapshots = new HashSet<>();
            for (NeuralStatsNodeResponse response : responses) {
                // A missing entry contributes null, which is handed to the aggregator as-is.
                nodeSnapshots.add(response.getEventStats().get(eventStatName));
            }
            TimestampedEventStatSnapshot aggregated = TimestampedEventStatSnapshot.aggregateEventStatSnapshots(nodeSnapshots);
            if (aggregated != null) {
                aggregatedMap.put(eventStatName.getFullPath(), aggregated);
            }
        }
    }

    /** Puts one cross-node aggregate per requested metric stat into {@code aggregatedMap}, keyed by full path. */
    private static void aggregateMetricStats(
        List<NeuralStatsNodeResponse> responses,
        EnumSet<MetricStatName> metricStatsToRetrieve,
        Map<String, StatSnapshot<?>> aggregatedMap
    ) {
        for (MetricStatName metricStatName : metricStatsToRetrieve) {
            HashSet<MemoryStatSnapshot> nodeSnapshots = new HashSet<>();
            for (NeuralStatsNodeResponse response : responses) {
                MemoryStatSnapshot snapshot = response.getMetricStats().get(metricStatName);
                // Skip nodes that did not report this metric; dereferencing the missing entry
                // would throw NullPointerException. Also skip snapshots not flagged for aggregation.
                if (snapshot != null && snapshot.isAggregationMetric()) {
                    nodeSnapshots.add(snapshot);
                }
            }
            MemoryStatSnapshot aggregated = MemoryStatSnapshot.aggregateMetricSnapshots(nodeSnapshots);
            if (aggregated != null) {
                aggregatedMap.put(metricStatName.getFullPath(), aggregated);
            }
        }
    }

    /**
     * Builds the per-node view of the stats.
     *
     * @param nodeResponses per-node responses
     * @return map from node id to (stat full path -> snapshot). Event and metric stats share one
     *         inner map; a metric stat overwrites an event stat with the same full path.
     */
    private Map<String, Map<String, StatSnapshot<?>>> processNodeStatsIntoMap(List<NeuralStatsNodeResponse> nodeResponses) {
        Map<String, Map<String, StatSnapshot<?>>> results = new HashMap<>();
        for (NeuralStatsNodeResponse nodeResponse : nodeResponses) {
            Map<String, StatSnapshot<?>> nodeStats = new HashMap<>();
            nodeResponse.getEventStats().forEach((statName, snapshot) -> nodeStats.put(statName.getFullPath(), snapshot));
            nodeResponse.getMetricStats().forEach((statName, snapshot) -> nodeStats.put(statName.getFullPath(), snapshot));
            results.put(nodeResponse.getNode().getId(), nodeStats);
        }
        return results;
    }
}

