/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.test.context.jdbc;

import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;
import javax.sql.DataSource;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.aot.hint.ResourceHints;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator;
import org.springframework.lang.Nullable;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestContextAnnotationUtils;
import org.springframework.test.context.aot.AotTestExecutionListener;
import org.springframework.test.context.jdbc.MergedSqlConfig;
import org.springframework.test.context.jdbc.Sql;
import org.springframework.test.context.jdbc.SqlConfig;
import org.springframework.test.context.jdbc.SqlGroup;
import org.springframework.test.context.jdbc.SqlMergeMode;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.test.context.transaction.TestContextTransactionUtils;
import org.springframework.test.context.util.TestContextResourceUtils;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.interceptor.DefaultTransactionAttribute;
import org.springframework.transaction.interceptor.TransactionAttribute;
import org.springframework.transaction.support.TransactionSynchronizationUtils;
import org.springframework.transaction.support.TransactionTemplate;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;

/**
 * Test execution listener that executes the SQL scripts and inlined SQL
 * statements declared via {@link Sql @Sql} on a test class or test method.
 * During AOT processing it also registers the {@code classpath:} script
 * resources referenced by those declarations as runtime resource hints.
 *
 * <p>Class-level {@code @Sql} declarations are processed for every
 * {@link Sql.ExecutionPhase}. Method-level declarations are only permitted to
 * use {@code BEFORE_TEST_METHOD} or {@code AFTER_TEST_METHOD}; any other phase
 * triggers an assertion failure.
 */
public class SqlScriptsTestExecutionListener
extends AbstractTestExecutionListener
implements AotTestExecutionListener {
    private static final String SLASH = "/";
    private static final Log logger = LogFactory.getLog(SqlScriptsTestExecutionListener.class);

    /**
     * Matches user-declared methods for which
     * {@code AnnotatedElementUtils.hasAnnotation(method, Sql.class)} is true.
     */
    private static final ReflectionUtils.MethodFilter sqlMethodFilter =
            ReflectionUtils.USER_DECLARED_METHODS.and(method -> AnnotatedElementUtils.hasAnnotation(method, Sql.class));

    /**
     * Returns {@code 5000}.
     */
    @Override
    public final int getOrder() {
        return 5000;
    }

    /**
     * Executes class-level {@code @Sql} declarations configured for
     * {@link Sql.ExecutionPhase#BEFORE_TEST_CLASS}.
     */
    @Override
    public void beforeTestClass(TestContext testContext) throws Exception {
        this.executeClassLevelSqlScripts(testContext, Sql.ExecutionPhase.BEFORE_TEST_CLASS);
    }

    /**
     * Executes class-level {@code @Sql} declarations configured for
     * {@link Sql.ExecutionPhase#AFTER_TEST_CLASS}.
     */
    @Override
    public void afterTestClass(TestContext testContext) throws Exception {
        this.executeClassLevelSqlScripts(testContext, Sql.ExecutionPhase.AFTER_TEST_CLASS);
    }

    /**
     * Executes {@code @Sql} declarations configured for
     * {@link Sql.ExecutionPhase#BEFORE_TEST_METHOD}, honoring {@code @SqlMergeMode}.
     */
    @Override
    public void beforeTestMethod(TestContext testContext) {
        this.executeSqlScripts(testContext, Sql.ExecutionPhase.BEFORE_TEST_METHOD);
    }

    /**
     * Executes {@code @Sql} declarations configured for
     * {@link Sql.ExecutionPhase#AFTER_TEST_METHOD}, honoring {@code @SqlMergeMode}.
     */
    @Override
    public void afterTestMethod(TestContext testContext) {
        this.executeSqlScripts(testContext, Sql.ExecutionPhase.AFTER_TEST_METHOD);
    }

    /**
     * Registers runtime resource hints for every {@code classpath:} script
     * referenced by class-level and method-level {@code @Sql} declarations on
     * the supplied test class (including detected default scripts).
     */
    @Override
    public void processAheadOfTime(RuntimeHints runtimeHints, Class<?> testClass, ClassLoader classLoader) {
        // Class-level declarations
        this.getSqlAnnotationsFor(testClass).forEach(sql ->
                this.registerClasspathResources(this.getScripts(sql, testClass, null, true), runtimeHints, classLoader));
        // Method-level declarations
        this.getSqlMethods(testClass).forEach(testMethod ->
                this.getSqlAnnotationsFor(testMethod).forEach(sql ->
                        this.registerClasspathResources(
                                this.getScripts(sql, testClass, testMethod, false), runtimeHints, classLoader)));
    }

    /**
     * Executes the class-level {@code @Sql} declarations of the current test
     * class that match the supplied execution phase.
     */
    private void executeClassLevelSqlScripts(TestContext testContext, Sql.ExecutionPhase executionPhase) {
        Class<?> testClass = testContext.getTestClass();
        this.executeSqlScripts(this.getSqlAnnotationsFor(testClass), testContext, executionPhase, true);
    }

    /**
     * Executes the {@code @Sql} declarations relevant to the current test method.
     * <p>In {@code MERGE} mode, class-level declarations run first, followed by
     * method-level declarations. Otherwise, method-level declarations (if any)
     * replace the class-level ones.
     */
    private void executeSqlScripts(TestContext testContext, Sql.ExecutionPhase executionPhase) {
        Method testMethod = testContext.getTestMethod();
        Class<?> testClass = testContext.getTestClass();

        if (this.mergeSqlAnnotations(testContext)) {
            this.executeSqlScripts(this.getSqlAnnotationsFor(testClass), testContext, executionPhase, true);
            this.executeSqlScripts(this.getSqlAnnotationsFor(testMethod), testContext, executionPhase, false);
            return;
        }

        Set<Sql> methodLevelSqlAnnotations = this.getSqlAnnotationsFor(testMethod);
        if (!methodLevelSqlAnnotations.isEmpty()) {
            this.executeSqlScripts(methodLevelSqlAnnotations, testContext, executionPhase, false);
        } else {
            this.executeSqlScripts(this.getSqlAnnotationsFor(testClass), testContext, executionPhase, true);
        }
    }

    /**
     * Determines whether class-level and method-level {@code @Sql} declarations
     * should be merged. A method-level {@code @SqlMergeMode} takes precedence
     * over a class-level one.
     *
     * @return {@code true} only if the effective merge mode is {@code MERGE}
     */
    private boolean mergeSqlAnnotations(TestContext testContext) {
        SqlMergeMode sqlMergeMode = this.getSqlMergeModeFor(testContext.getTestMethod());
        if (sqlMergeMode == null) {
            sqlMergeMode = this.getSqlMergeModeFor(testContext.getTestClass());
        }
        return (sqlMergeMode != null && sqlMergeMode.value() == SqlMergeMode.MergeMode.MERGE);
    }

    /** Finds the merged {@code @SqlMergeMode} for the supplied class, or {@code null}. */
    @Nullable
    private SqlMergeMode getSqlMergeModeFor(Class<?> clazz) {
        return TestContextAnnotationUtils.findMergedAnnotation(clazz, SqlMergeMode.class);
    }

    /** Finds the merged {@code @SqlMergeMode} for the supplied method, or {@code null}. */
    @Nullable
    private SqlMergeMode getSqlMergeModeFor(Method method) {
        return AnnotatedElementUtils.findMergedAnnotation(method, SqlMergeMode.class);
    }

    /** Returns the merged repeatable {@code @Sql} declarations for the supplied class. */
    private Set<Sql> getSqlAnnotationsFor(Class<?> clazz) {
        return TestContextAnnotationUtils.getMergedRepeatableAnnotations(clazz, Sql.class);
    }

    /**
     * Returns the merged repeatable {@code @Sql} declarations for the supplied
     * method, using {@code @SqlGroup} as the container annotation.
     */
    private Set<Sql> getSqlAnnotationsFor(Method method) {
        return AnnotatedElementUtils.getMergedRepeatableAnnotations(method, Sql.class, SqlGroup.class);
    }

    /** Executes each of the supplied {@code @Sql} declarations in iteration order. */
    private void executeSqlScripts(
            Set<Sql> sqlAnnotations, TestContext testContext, Sql.ExecutionPhase executionPhase, boolean classLevel) {

        sqlAnnotations.forEach(sql -> this.executeSqlScripts(sql, executionPhase, testContext, classLevel));
    }

    /**
     * Executes the scripts and inlined statements of a single {@code @Sql}
     * declaration, provided its execution phase matches {@code executionPhase}.
     * <p>Without a transaction manager, the scripts run directly against the
     * configured {@code DataSource}. With a transaction manager, they run within a
     * {@code TransactionTemplate}. That template uses {@code PROPAGATION_REQUIRES_NEW}
     * for {@code ISOLATED} transaction mode and {@code PROPAGATION_REQUIRED} otherwise.
     *
     * @throws IllegalArgumentException if a method-level declaration uses a class-level phase
     * @throws IllegalStateException if no usable {@code DataSource}/transaction manager
     * combination can be resolved
     */
    private void executeSqlScripts(
            Sql sql, Sql.ExecutionPhase executionPhase, TestContext testContext, boolean classLevel) {

        // Validated for every phase callback, not just the matching one.
        Assert.isTrue(classLevel || isValidMethodLevelPhase(sql.executionPhase()),
                () -> "@SQL execution phase %s cannot be used on methods".formatted(sql.executionPhase()));
        if (executionPhase != sql.executionPhase()) {
            return;
        }

        MergedSqlConfig mergedSqlConfig = new MergedSqlConfig(sql.config(), testContext.getTestClass());
        if (logger.isTraceEnabled()) {
            logger.trace("Processing %s for execution phase [%s] and test context %s"
                    .formatted(mergedSqlConfig, executionPhase, testContext));
        } else if (logger.isDebugEnabled()) {
            logger.debug("Processing merged @SqlConfig attributes for execution phase [%s] and test class [%s]"
                    .formatted(executionPhase, testContext.getTestClass().getName()));
        }

        boolean methodLevel = !classLevel;
        Method testMethod = (methodLevel ? testContext.getTestMethod() : null);
        String[] scripts = this.getScripts(sql, testContext.getTestClass(), testMethod, classLevel);
        ApplicationContext applicationContext = testContext.getApplicationContext();
        List<Resource> scriptResources = TestContextResourceUtils.convertToResourceList(
                applicationContext, applicationContext.getEnvironment(), scripts);
        addInlinedStatements(sql, mergedSqlConfig.getEncoding(), scriptResources);

        ResourceDatabasePopulator populator = this.createDatabasePopulator(mergedSqlConfig);
        populator.setScripts(scriptResources.toArray(new Resource[0]));
        if (logger.isDebugEnabled()) {
            logger.debug("Executing SQL scripts: " + scriptResources);
        }

        String dsName = mergedSqlConfig.getDataSource();
        String tmName = mergedSqlConfig.getTransactionManager();
        DataSource dataSource = TestContextTransactionUtils.retrieveDataSource(testContext, dsName);
        PlatformTransactionManager txMgr = TestContextTransactionUtils.retrieveTransactionManager(testContext, tmName);
        boolean newTxRequired = (mergedSqlConfig.getTransactionMode() == SqlConfig.TransactionMode.ISOLATED);

        if (txMgr == null) {
            Assert.state(!newTxRequired, () -> String.format("Failed to execute SQL scripts for test context %s: " +
                    "cannot execute SQL scripts using Transaction Mode [%s] without a PlatformTransactionManager.",
                    testContext, SqlConfig.TransactionMode.ISOLATED));
            Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for test context %s: " +
                    "supply at least a DataSource or PlatformTransactionManager.", testContext));
            populator.execute(dataSource);
            return;
        }

        DataSource dataSourceFromTxMgr = this.getDataSourceFromTransactionManager(txMgr);
        // An explicitly configured DataSource must be the one managed by the transaction manager.
        if (dataSource != null && dataSourceFromTxMgr != null && !sameDataSource(dataSource, dataSourceFromTxMgr)) {
            throw new IllegalStateException(String.format("Failed to execute SQL scripts for test context %s: " +
                    "the configured DataSource [%s] (named '%s') is not the one associated with " +
                    "transaction manager [%s] (named '%s').", testContext, dataSource.getClass().getName(),
                    dsName, txMgr.getClass().getName(), tmName));
        }
        if (dataSource == null) {
            dataSource = dataSourceFromTxMgr;
            Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for test context %s: " +
                    "could not obtain DataSource from transaction manager [%s] (named '%s').",
                    testContext, txMgr.getClass().getName(), tmName));
        }

        DataSource finalDataSource = dataSource;
        int propagation = (newTxRequired ?
                TransactionDefinition.PROPAGATION_REQUIRES_NEW : TransactionDefinition.PROPAGATION_REQUIRED);
        TransactionAttribute txAttr = TestContextTransactionUtils.createDelegatingTransactionAttribute(
                testContext, new DefaultTransactionAttribute(propagation), methodLevel);
        new TransactionTemplate(txMgr, txAttr).executeWithoutResult(status -> populator.execute(finalDataSource));
    }

    /**
     * Converts the non-blank inlined statements of the supplied {@code @Sql}
     * declaration into in-memory resources and appends them to {@code scriptResources}.
     * <p>Each statement is trimmed. It is then encoded with the script encoding
     * configured via {@code @SqlConfig}, which is the same value passed to
     * {@code ResourceDatabasePopulator#setSqlScriptEncoding}. When no encoding is
     * configured (null or blank), the platform default charset is used.
     *
     * @param sql the {@code @Sql} declaration whose statements are converted
     * @param encoding the configured script encoding, possibly {@code null} or blank
     * @param scriptResources the mutable list to append to
     */
    private static void addInlinedStatements(Sql sql, @Nullable String encoding, List<Resource> scriptResources) {
        Charset charset = null;
        for (String statement : sql.statements()) {
            if (!StringUtils.hasText(statement)) {
                continue;
            }
            if (charset == null) {
                // Resolved lazily so that declarations without inlined statements are unaffected.
                charset = (StringUtils.hasText(encoding) ? Charset.forName(encoding) : Charset.defaultCharset());
            }
            String trimmed = statement.trim();
            scriptResources.add(new ByteArrayResource(trimmed.getBytes(charset), "from inlined SQL statement: " + trimmed));
        }
    }

    /** Creates a {@code ResourceDatabasePopulator} configured from the merged {@code @SqlConfig}. */
    private ResourceDatabasePopulator createDatabasePopulator(MergedSqlConfig mergedSqlConfig) {
        ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
        populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
        populator.setSeparator(mergedSqlConfig.getSeparator());
        populator.setCommentPrefixes(mergedSqlConfig.getCommentPrefixes());
        populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
        populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
        populator.setContinueOnError(mergedSqlConfig.getErrorMode() == SqlConfig.ErrorMode.CONTINUE_ON_ERROR);
        populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == SqlConfig.ErrorMode.IGNORE_FAILED_DROPS);
        return populator;
    }

    /**
     * Compares two data sources after unwrapping each via
     * {@code TransactionSynchronizationUtils.unwrapResourceIfNecessary}.
     */
    private static boolean sameDataSource(DataSource ds1, DataSource ds2) {
        return TransactionSynchronizationUtils.unwrapResourceIfNecessary(ds1)
                .equals(TransactionSynchronizationUtils.unwrapResourceIfNecessary(ds2));
    }

    /**
     * Reflectively invokes a public no-arg {@code getDataSource()} method on the
     * supplied transaction manager.
     *
     * @return the returned {@code DataSource}, or {@code null} if no such method
     * exists, the invocation fails, or the result is not a {@code DataSource}
     */
    @Nullable
    private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
        try {
            Method getDataSourceMethod = transactionManager.getClass().getMethod("getDataSource");
            Object obj = ReflectionUtils.invokeMethod(getDataSourceMethod, transactionManager);
            if (obj instanceof DataSource) {
                return (DataSource) obj;
            }
        } catch (Exception ignored) {
            // Deliberate best effort: not every PlatformTransactionManager exposes
            // getDataSource(), so any failure simply means "unknown" (null).
        }
        return null;
    }

    /**
     * Resolves the script paths of the supplied {@code @Sql} declaration as
     * classpath resource paths relative to {@code testClass}. If neither scripts
     * nor statements are declared, the default script is detected instead.
     */
    private String[] getScripts(Sql sql, Class<?> testClass, @Nullable Method testMethod, boolean classLevel) {
        String[] scripts = sql.scripts();
        if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
            scripts = new String[] {this.detectDefaultScript(testClass, testMethod, classLevel)};
        }
        return TestContextResourceUtils.convertToClasspathResourcePaths(testClass, scripts);
    }

    /**
     * Detects the default SQL script for the test class or test method.
     * <p>The default path is derived from the test class name:
     * {@code com/example/MyTest.sql} for a class, or
     * {@code com/example/MyTest.myMethod.sql} for a method.
     *
     * @return the detected script as a {@code classpath:/...} location
     * @throws IllegalStateException if the default script does not exist on the classpath,
     * or if {@code testMethod} is {@code null} for a method-level declaration
     */
    private String detectDefaultScript(Class<?> testClass, @Nullable Method testMethod, boolean classLevel) {
        Assert.state(classLevel || testMethod != null, "Method-level @Sql requires a testMethod");

        String elementType = (classLevel ? "class" : "method");
        String elementName = (classLevel ? testClass.getName() : testMethod.toString());

        String resourcePath = ClassUtils.convertClassNameToResourcePath(testClass.getName());
        if (!classLevel) {
            resourcePath += "." + testMethod.getName();
        }
        resourcePath += ".sql";
        String prefixedResourcePath = "classpath:" + SLASH + resourcePath;

        ClassPathResource classPathResource = new ClassPathResource(resourcePath);
        if (classPathResource.exists()) {
            if (logger.isDebugEnabled()) {
                logger.debug("Detected default SQL script \"%s\" for test %s [%s]"
                        .formatted(prefixedResourcePath, elementType, elementName));
            }
            return prefixedResourcePath;
        }

        String msg = String.format("Could not detect default SQL script for test %s [%s]: " +
                "%s does not exist. Either declare statements or scripts via @Sql or make the " +
                "default SQL script available.", elementType, elementName, classPathResource);
        logger.error(msg);
        throw new IllegalStateException(msg);
    }

    /** Streams the unique user-declared methods of {@code testClass} that carry {@code @Sql}. */
    private Stream<Method> getSqlMethods(Class<?> testClass) {
        return Arrays.stream(ReflectionUtils.getUniqueDeclaredMethods(testClass, sqlMethodFilter));
    }

    /**
     * Registers a resource hint for each path that starts with {@code classpath:},
     * resolving it with a {@code DefaultResourceLoader} bound to {@code classLoader}.
     * Other paths are skipped.
     */
    private void registerClasspathResources(String[] paths, RuntimeHints runtimeHints, ClassLoader classLoader) {
        DefaultResourceLoader resourceLoader = new DefaultResourceLoader(classLoader);
        ResourceHints resourceHints = runtimeHints.resources();
        Arrays.stream(paths)
                .filter(path -> path.startsWith("classpath:"))
                .map(resourceLoader::getResource)
                .forEach(resourceHints::registerResource);
    }

    /** Returns {@code true} for the two phases that may be declared on test methods. */
    private static boolean isValidMethodLevelPhase(Sql.ExecutionPhase executionPhase) {
        return (executionPhase == Sql.ExecutionPhase.BEFORE_TEST_METHOD ||
                executionPhase == Sql.ExecutionPhase.AFTER_TEST_METHOD);
    }
}

