/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common.model;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.common.model.Guardrail;
import org.opensearch.ml.common.model.StopWords;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.transport.client.Client;

/**
 * Guardrail that rejects input matching any configured regular expression or any stop-word
 * query stored in a configured index.
 *
 * <p>Regex checks run locally against patterns compiled during {@code init}; an input fails a
 * pattern when {@link Matcher#matches()} succeeds, i.e. the whole input matches. Stop-word checks
 * percolate the input against the {@code query} field of each configured index through the
 * {@link SdkClient}; any hit is a violation.
 *
 * <p>{@link #init(NamedXContentRegistry, Client, SdkClient, String)} must be called before
 * {@link #validate(String, Map)}. Until then the compiled patterns and the index map are
 * {@code null}, and both checks are skipped (validation passes).
 */
public class LocalRegexGuardrail
extends Guardrail {
    @Generated
    private static final Logger log = LogManager.getLogger(LocalRegexGuardrail.class);
    public static final String STOP_WORDS_FIELD = "stop_words";
    public static final String REGEX_FIELD = "regex";
    /** Maximum time to block waiting for a single stop-words index search. */
    private static final long STOP_WORDS_SEARCH_TIMEOUT_SECONDS = 5L;

    /** Stop-word sources (index plus source fields); part of the serialized form. */
    private List<StopWords> stopWords;
    /** Raw regular expressions; part of the serialized form. */
    private String[] regex;
    /** Patterns compiled from {@link #regex} by {@code init()}; null before init. */
    private List<Pattern> regexPattern;
    /** Index name to source field names, derived from {@link #stopWords} by {@code init()}. */
    private Map<String, List<String>> stopWordsIndicesInput;
    private NamedXContentRegistry xContentRegistry;
    private Client client;
    private SdkClient sdkClient;
    private String tenantId;

    /**
     * Creates a guardrail from already-parsed configuration.
     *
     * @param stopWords stop-word sources, may be {@code null}
     * @param regex regular expressions, may be {@code null}
     */
    public LocalRegexGuardrail(List<StopWords> stopWords, String[] regex) {
        this.stopWords = stopWords;
        this.regex = regex;
    }

    /**
     * Creates a guardrail from a loosely-typed parameter map.
     *
     * <p>{@code stop_words} is read as a list of maps, each converted with {@code new StopWords(map)};
     * {@code regex} is read as a list of strings. {@link #stopWords} is always non-null afterwards,
     * while {@link #regex} stays {@code null} when the list is absent or empty.
     *
     * @param params configuration map; must not be {@code null}
     * @throws NullPointerException if {@code params} is {@code null}
     */
    @SuppressWarnings("unchecked")
    public LocalRegexGuardrail(@NonNull Map<String, Object> params) {
        Objects.requireNonNull(params, "params is marked non-null but is null");
        this.stopWords = new ArrayList<>();
        List<Map<String, Object>> words = (List<Map<String, Object>>) params.get(STOP_WORDS_FIELD);
        if (words != null) {
            for (Map<String, Object> word : words) {
                this.stopWords.add(new StopWords(word));
            }
        }
        List<String> regexes = (List<String>) params.get(REGEX_FIELD);
        if (regexes != null && !regexes.isEmpty()) {
            this.regex = regexes.toArray(new String[0]);
        }
    }

    /**
     * Reads a guardrail written by {@link #writeTo(StreamOutput)}.
     *
     * @param input stream positioned at a serialized guardrail
     * @throws IOException if reading fails
     */
    public LocalRegexGuardrail(StreamInput input) throws IOException {
        if (input.readBoolean()) {
            this.stopWords = new ArrayList<>();
            int size = input.readInt();
            for (int i = 0; i < size; ++i) {
                this.stopWords.add(new StopWords(input));
            }
        }
        this.regex = input.readOptionalStringArray();
    }

    /**
     * Writes the serializable configuration: a presence flag, then the count and entries of
     * {@link #stopWords}, followed by the optional {@link #regex} array. An empty stop-word list
     * is written as absent.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        if (this.stopWords != null && !this.stopWords.isEmpty()) {
            out.writeBoolean(true);
            out.writeInt(this.stopWords.size());
            for (StopWords words : this.stopWords) {
                words.writeTo(out);
            }
        } else {
            out.writeBoolean(false);
        }
        out.writeOptionalStringArray(this.regex);
    }

    /**
     * Returns {@code true} when {@code input} passes both the regex and the stop-word checks.
     * The stop-word check (which searches remote indices) runs only if the regex check passed.
     * {@code parameters} is not used by this implementation.
     */
    @Override
    public Boolean validate(String input, Map<String, String> parameters) {
        return this.validateRegexList(input, this.regexPattern) && this.validateStopWords(input, this.stopWordsIndicesInput);
    }

    /**
     * Stores the clients and tenant used for stop-word searches, then compiles the regex
     * patterns and builds the index-to-fields map.
     *
     * @throws java.util.regex.PatternSyntaxException if a configured regex is invalid
     */
    @Override
    public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
        this.xContentRegistry = xContentRegistry;
        this.client = client;
        this.sdkClient = sdkClient;
        this.tenantId = tenantId;
        this.init();
    }

    /**
     * Renders the configuration as an object. {@code stop_words} is omitted when it is null or
     * empty; {@code regex} is omitted when null.
     */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.stopWords != null && !this.stopWords.isEmpty()) {
            builder.field(STOP_WORDS_FIELD, this.stopWords);
        }
        if (this.regex != null) {
            // Cast keeps the field(String, Object) overload, which writes the array as a list.
            builder.field(REGEX_FIELD, (Object) this.regex);
        }
        builder.endObject();
        return builder;
    }

    /**
     * Parses a guardrail from an object whose current token is {@code START_OBJECT}. Unknown
     * fields are skipped.
     *
     * @throws IOException if parsing fails or the expected tokens are not found
     */
    public static LocalRegexGuardrail parse(XContentParser parser) throws IOException {
        List<StopWords> stopWords = null;
        String[] regex = null;
        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();
            switch (fieldName) {
                case STOP_WORDS_FIELD:
                    stopWords = new ArrayList<>();
                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                        stopWords.add(StopWords.parse(parser));
                    }
                    break;
                case REGEX_FIELD:
                    regex = parser.list().toArray(new String[0]);
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        return LocalRegexGuardrail.builder().stopWords(stopWords).regex(regex).build();
    }

    /** Derives the runtime lookup structures from the serializable configuration. */
    private void init() {
        this.stopWordsIndicesInput = this.stopWordsToMap();
        List<String> regexList = this.regex == null ? List.of() : Arrays.asList(this.regex);
        this.regexPattern = regexList.stream().map(Pattern::compile).collect(Collectors.toList());
    }

    /**
     * Maps each stop-word index name to its source fields. Entries missing an index or fields
     * are skipped; if an index appears more than once, the last entry wins.
     */
    private Map<String, List<String>> stopWordsToMap() {
        Map<String, List<String>> map = new HashMap<>();
        if (this.stopWords != null) {
            for (StopWords words : this.stopWords) {
                if (words.getIndex() == null || words.getSourceFields() == null) {
                    continue;
                }
                map.put(words.getIndex(), Arrays.asList(words.getSourceFields()));
            }
        }
        return map;
    }

    /**
     * Returns {@code true} when {@code input} matches none of {@code regexPatterns}; a
     * {@code null} or empty list always passes.
     */
    public Boolean validateRegexList(String input, List<Pattern> regexPatterns) {
        if (regexPatterns == null || regexPatterns.isEmpty()) {
            return true;
        }
        for (Pattern pattern : regexPatterns) {
            if (!this.validateRegex(input, pattern)) {
                return false;
            }
        }
        return true;
    }

    /** Returns {@code true} unless the entire {@code input} matches {@code pattern}. */
    public Boolean validateRegex(String input, Pattern pattern) {
        Matcher matcher = pattern.matcher(input);
        return !matcher.matches();
    }

    /**
     * Returns {@code true} when {@code input} passes the stop-word check of every index in
     * {@code stopWordsIndices}; a {@code null} or empty map always passes. Stops at the first
     * failing index.
     */
    public Boolean validateStopWords(String input, Map<String, List<String>> stopWordsIndices) {
        if (stopWordsIndices == null || stopWordsIndices.isEmpty()) {
            return true;
        }
        for (Map.Entry<String, List<String>> entry : stopWordsIndices.entrySet()) {
            if (!this.validateStopWordsSingleIndex(input, entry.getKey(), entry.getValue())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks {@code input} against the stop-word queries stored in a single index.
     *
     * <p>The input is copied into every field in {@code fieldNames} to form a document, which is
     * percolated against the {@code query} field of {@code indexName}. The caller blocks for up to
     * {@value #STOP_WORDS_SEARCH_TIMEOUT_SECONDS} seconds waiting for the asynchronous search.
     *
     * @return {@code true} when the search returns no hits, or when building or executing the
     *     search fails (the error is logged and the check is skipped); {@code false} when at least
     *     one stop-word query matched, or when the search did not complete before the timeout
     * @throws IllegalStateException if the calling thread is interrupted while waiting; the
     *     thread's interrupt status is restored first
     */
    public Boolean validateStopWordsSingleIndex(String input, String indexName, List<String> fieldNames) {
        // Stays false unless a callback (or the error path) explicitly marks the check as passed.
        AtomicBoolean passedStopWordCheck = new AtomicBoolean(false);
        Map<String, String> documentMap = new HashMap<>();
        for (String field : fieldNames) {
            documentMap.put(field, input);
        }
        Map<String, Object> queryBodyMap = Map.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
        CountDownLatch latch = new CountDownLatch(1);
        try {
            String queryBody = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> StringUtils.gson.toJson(queryBodyMap));
            SearchDataObjectRequest searchDataObjectRequest = this.buildSearchDataObjectRequest(indexName, queryBody);
            LatchedActionListener<SearchResponse> responseListener = new LatchedActionListener<>(ActionListener.<SearchResponse>wrap(r -> {
                if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0L) {
                    passedStopWordCheck.set(true);
                }
            }, e -> {
                log.error("Failed to search stop words index {}", indexName, e);
                passedStopWordCheck.set(true);
            }), latch);
            // Issue the search under a stashed thread context; it is restored before the listener runs.
            try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
                this.sdkClient
                    .searchDataObjectAsync(searchDataObjectRequest)
                    .whenComplete(SdkClientUtils.wrapSearchCompletion(ActionListener.runBefore(responseListener, context::restore)));
            }
        } catch (Exception e) {
            log.error("[validateStopWords] Searching stop words index failed.", e);
            latch.countDown();
            passedStopWordCheck.set(true);
        }
        try {
            if (!latch.await(STOP_WORDS_SEARCH_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                // The flag is still false here, so a timed-out search rejects the input.
                log.warn(
                    "[validateStopWords] Searching stop words index {} timed out after {} seconds.",
                    indexName,
                    STOP_WORDS_SEARCH_TIMEOUT_SECONDS
                );
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("[validateStopWords] Searching stop words index was interrupted.", e);
            throw new IllegalStateException(e);
        }
        return passedStopWordCheck.get();
    }

    /**
     * Builds a search request for {@code indexName} from a JSON query body, limited to one hit
     * and scoped to this guardrail's tenant.
     *
     * @throws IOException if {@code queryBody} cannot be parsed
     */
    protected SearchDataObjectRequest buildSearchDataObjectRequest(String indexName, String queryBody) throws IOException {
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        XContentParser queryParser = XContentType.JSON.xContent().createParser(this.xContentRegistry, (DeprecationHandler) LoggingDeprecationHandler.INSTANCE, queryBody);
        searchSourceBuilder.parseXContent(queryParser);
        searchSourceBuilder.size(1);
        return SearchDataObjectRequest.builder().indices(new String[]{indexName}).searchSourceBuilder(searchSourceBuilder).tenantId(this.tenantId).build();
    }

    @Generated
    public static LocalRegexGuardrailBuilder builder() {
        return new LocalRegexGuardrailBuilder();
    }

    /** Returns a builder pre-populated with the serializable configuration only. */
    @Generated
    public LocalRegexGuardrailBuilder toBuilder() {
        return new LocalRegexGuardrailBuilder().stopWords(this.stopWords).regex(this.regex);
    }

    @Override
    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LocalRegexGuardrail)) {
            return false;
        }
        LocalRegexGuardrail other = (LocalRegexGuardrail) o;
        return other.canEqual(this)
            && Objects.equals(this.getStopWords(), other.getStopWords())
            && Arrays.deepEquals(this.getRegex(), other.getRegex())
            && Objects.equals(this.getRegexPattern(), other.getRegexPattern())
            && Objects.equals(this.getStopWordsIndicesInput(), other.getStopWordsIndicesInput())
            && Objects.equals(this.getXContentRegistry(), other.getXContentRegistry())
            && Objects.equals(this.getClient(), other.getClient())
            && Objects.equals(this.getSdkClient(), other.getSdkClient())
            && Objects.equals(this.getTenantId(), other.getTenantId());
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof LocalRegexGuardrail;
    }

    @Override
    @Generated
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        result = result * PRIME + hashOrDefault(this.getStopWords());
        result = result * PRIME + Arrays.deepHashCode(this.getRegex());
        result = result * PRIME + hashOrDefault(this.getRegexPattern());
        result = result * PRIME + hashOrDefault(this.getStopWordsIndicesInput());
        result = result * PRIME + hashOrDefault(this.getXContentRegistry());
        result = result * PRIME + hashOrDefault(this.getClient());
        result = result * PRIME + hashOrDefault(this.getSdkClient());
        result = result * PRIME + hashOrDefault(this.getTenantId());
        return result;
    }

    /** Lombok-compatible member hash: 43 for {@code null}, otherwise {@code hashCode()}. */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    @Generated
    public List<StopWords> getStopWords() {
        return this.stopWords;
    }

    @Generated
    public String[] getRegex() {
        return this.regex;
    }

    @Generated
    public List<Pattern> getRegexPattern() {
        return this.regexPattern;
    }

    @Generated
    public Map<String, List<String>> getStopWordsIndicesInput() {
        return this.stopWordsIndicesInput;
    }

    @Generated
    public NamedXContentRegistry getXContentRegistry() {
        return this.xContentRegistry;
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public SdkClient getSdkClient() {
        return this.sdkClient;
    }

    @Generated
    public String getTenantId() {
        return this.tenantId;
    }

    /** Builder for the serializable configuration ({@code stopWords} and {@code regex}). */
    @Generated
    public static class LocalRegexGuardrailBuilder {
        @Generated
        private List<StopWords> stopWords;
        @Generated
        private String[] regex;

        @Generated
        LocalRegexGuardrailBuilder() {
        }

        @Generated
        public LocalRegexGuardrailBuilder stopWords(List<StopWords> stopWords) {
            this.stopWords = stopWords;
            return this;
        }

        @Generated
        public LocalRegexGuardrailBuilder regex(String[] regex) {
            this.regex = regex;
            return this;
        }

        @Generated
        public LocalRegexGuardrail build() {
            return new LocalRegexGuardrail(this.stopWords, this.regex);
        }

        @Override
        @Generated
        public String toString() {
            return "LocalRegexGuardrail.LocalRegexGuardrailBuilder(stopWords=" + String.valueOf(this.stopWords) + ", regex=" + Arrays.deepToString(this.regex) + ")";
        }
    }
}

