package com.yeejoin.amos.boot.module.common.biz.service.impl;

import org.apache.commons.lang.ArrayUtils;
import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.Scroll;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.springframework.stereotype.Component;
import org.typroject.tyboot.core.foundation.utils.ValidationUtil;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;

/**
 * Helper around {@link RestHighLevelClient} that reads all documents matching a query
 * using the Elasticsearch scroll API, either batch by batch or collected into a list.
 */
@Component
public class EsSearchServiceImpl {

    final RestHighLevelClient restHighLevelClient;

    /**
     * Default number of hits requested per scroll page.
     */
    private static final int SIZE = 5000;

    /**
     * Scroll keep-alive, in minutes, applied to the initial search and to every scroll request.
     */
    private static final long SCROLL_TIMEOUT = 5;

    public EsSearchServiceImpl(RestHighLevelClient restHighLevelClient) {
        this.restHighLevelClient = restHighLevelClient;
    }

    /**
     * Runs a scroll search and hands every page of hits to {@code consumer}.
     * <p>
     * Paging stops when a page is empty or contains fewer hits than the requested page size.
     * The scroll ids seen during the run are cleared afterwards, whether or not the run succeeded.
     *
     * @param indices   index to search
     * @param query     query to execute
     * @param batchSize hits per page; values {@code <= 0} fall back to the default {@link #SIZE}
     * @param consumer  callback invoked once per non-empty page of hits
     * @throws Exception if a search/scroll/clear-scroll request fails or the consumer throws;
     *                   a clear-scroll failure that happens after another failure is attached
     *                   to that failure as a suppressed exception instead of replacing it
     */
    public void searchResponseInBatch(String indices, QueryBuilder query, int batchSize, Consumer<List<SearchHit>> consumer) throws Exception {
        // A primitive int is never null/empty once boxed, so validate the value explicitly.
        int searchSize = batchSize > 0 ? batchSize : SIZE;

        Scroll scroll = new Scroll(TimeValue.timeValueMinutes(SCROLL_TIMEOUT));
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
        sourceBuilder.query(query);
        sourceBuilder.size(searchSize);

        SearchRequest request = new SearchRequest(indices);
        request.scroll(scroll);
        request.source(sourceBuilder);

        List<String> scrollIdList = new ArrayList<>();
        SearchResponse response = restHighLevelClient.search(request, RequestOptions.DEFAULT);
        String scrollId = response.getScrollId();
        rememberScrollId(scrollIdList, scrollId);

        // Tracks the primary failure so a cleanup error cannot mask it.
        Throwable failure = null;
        try {
            SearchHit[] hits = response.getHits().getHits();
            while (ArrayUtils.isNotEmpty(hits)) {
                // Hand the current page to the caller.
                consumer.accept(Arrays.asList(hits));

                // A short page means there is nothing left to fetch.
                if (hits.length < searchSize) {
                    break;
                }

                SearchScrollRequest searchScrollRequest = new SearchScrollRequest(scrollId);
                searchScrollRequest.scroll(scroll);
                response = restHighLevelClient.scroll(searchScrollRequest, RequestOptions.DEFAULT);
                scrollId = response.getScrollId();
                rememberScrollId(scrollIdList, scrollId);
                hits = response.getHits().getHits();
            }
        } catch (Throwable t) {
            failure = t;
            throw t;
        } finally {
            releaseScrolls(scrollIdList, failure);
        }
    }

    /**
     * Runs a scroll search and converts every hit with {@code fun}, collecting the results.
     *
     * @param indices index to search
     * @param query   query to execute
     * @param fun     mapping applied to each hit
     * @param <T>     element type of the returned list
     * @return the mapped hits, in the order the pages were received
     * @throws Exception if the underlying batch search fails or {@code fun} throws
     */
    public <T> List<T> searchResponse(String indices, QueryBuilder query, Function<SearchHit, T> fun) throws Exception {
        List<T> result = new ArrayList<>();

        // Reuse the batch method and accumulate every page into one list.
        searchResponseInBatch(indices, query, SIZE, hits -> {
            for (SearchHit hit : hits) {
                result.add(fun.apply(hit));
            }
        });

        return result;
    }

    /**
     * Records a scroll id for later cleanup, skipping {@code null} and ids already recorded.
     */
    private static void rememberScrollId(List<String> scrollIdList, String scrollId) {
        if (scrollId != null && !scrollIdList.contains(scrollId)) {
            scrollIdList.add(scrollId);
        }
    }

    /**
     * Clears the recorded scroll ids; does nothing when none were recorded.
     * <p>
     * If clearing fails while {@code failure} is already set, the cleanup error is added to it
     * as a suppressed exception; otherwise the cleanup error is thrown.
     */
    private void releaseScrolls(List<String> scrollIdList, Throwable failure) throws Exception {
        if (scrollIdList.isEmpty()) {
            return;
        }
        try {
            ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
            clearScrollRequest.setScrollIds(scrollIdList);
            restHighLevelClient.clearScroll(clearScrollRequest, RequestOptions.DEFAULT);
        } catch (Exception cleanupError) {
            if (failure == null) {
                throw cleanupError;
            }
            failure.addSuppressed(cleanupError);
        }
    }
}
