/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sshd.server.auth;

import java.net.MalformedURLException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.auth.keyboard.UserInteraction;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.common.config.keys.KeyRandomArt;
import org.apache.sshd.common.keyprovider.KeyPairProvider;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.core.CoreModuleProperties;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.util.test.BaseTestSupport;
import org.apache.sshd.util.test.CoreTestSupportUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer.MethodName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;

/**
 * Verifies that the value configured via {@link CoreModuleProperties#WELCOME_BANNER} on the server is delivered to the
 * client's {@link UserInteraction#welcome(ClientSession, String, String)} callback. Covered sources are plain strings,
 * the auto-generated key random art, and file contents referenced as {@link Path}, {@link java.io.File}, URI, URI
 * string or URL. A missing or empty banner file must result in no banner callback at all.
 *
 * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
 */
@TestMethodOrder(MethodName.class)
class WelcomeBannerTest extends BaseTestSupport {

    // Shared by all tests; each test sets its own WELCOME_BANNER value on the server before connecting.
    private static SshServer sshd;
    private static int port;
    private static SshClient client;

    WelcomeBannerTest() {
        super();
    }

    /**
     * Starts the shared test server and client used by every test in this class.
     */
    @BeforeAll
    static void setupClientAndServer() throws Exception {
        sshd = CoreTestSupportUtils.setupTestServer(WelcomeBannerTest.class);
        sshd.start();
        port = sshd.getPort();

        client = CoreTestSupportUtils.setupTestClient(WelcomeBannerTest.class);
        client.start();
    }

    /**
     * Stops the shared server and client. The client is stopped even if stopping the server fails, so a failing server
     * shutdown does not leak the client.
     */
    @AfterAll
    static void tearDownClientAndServer() throws Exception {
        try {
            if (sshd != null) {
                try {
                    sshd.stop(true);
                } finally {
                    sshd = null;
                }
            }
        } finally {
            if (client != null) {
                try {
                    client.stop();
                } finally {
                    client = null;
                }
            }
        }
    }

    @Test
    void simpleBanner() throws Exception {
        String expectedWelcome = "Welcome to SSHD WelcomeBannerTest";
        CoreModuleProperties.WELCOME_BANNER.set(sshd, expectedWelcome);
        testBanner(expectedWelcome);
    }

    // see SSHD-686
    @Test
    void autoGeneratedBanner() throws Exception {
        KeyPairProvider keys = sshd.getKeyPairProvider();
        CoreModuleProperties.WELCOME_BANNER.set(sshd, CoreModuleProperties.AUTO_WELCOME_BANNER_VALUE);
        testBanner(KeyRandomArt.combine(null, ' ', keys));
    }

    @Test
    void pathBanner() throws Exception {
        testFileContentBanner(Function.identity());
    }

    @Test
    void fileBanner() throws Exception {
        testFileContentBanner(Path::toFile);
    }

    @Test
    void uriBanner() throws Exception {
        testFileContentBanner(Path::toUri);
    }

    @Test
    void uriStringBanner() throws Exception {
        testFileContentBanner(path -> Objects.toString(path.toUri()));
    }

    @Test
    void urlBanner() throws Exception {
        testFileContentBanner(path -> {
            try {
                return path.toUri().toURL();
            } catch (MalformedURLException e) {
                throw new RuntimeException(e);
            }
        });
    }

    @Test
    void fileNotExistsBanner() throws Exception {
        Path file = prepareBannerFile();
        assertFalse(Files.exists(file), "Banner file not deleted: " + file);
        CoreModuleProperties.WELCOME_BANNER.set(sshd, file);
        testBanner(null);
    }

    @Test
    void emptyFileBanner() throws Exception {
        Path file = prepareBannerFile();
        Files.write(file, GenericUtils.EMPTY_BYTE_ARRAY);
        assertTrue(Files.exists(file), "Empty file not created: " + file);
        CoreModuleProperties.WELCOME_BANNER.set(sshd, file);
        testBanner(null);
    }

    // see SSHD-695
    @Test
    void welcomeBannerBeforeAuthBegins() throws Exception {
        UserInteraction ui = client.getUserInteraction();
        try {
            // Released by the welcome callback. The test waits for it right after connecting and
            // before starting authentication.
            Semaphore sigSem = new Semaphore(0);
            client.setUserInteraction(new UserInteraction() {
                @Override
                public void welcome(ClientSession session, String banner, String lang) {
                    sigSem.release();
                }

                @Override
                public boolean isInteractionAllowed(ClientSession session) {
                    return true;
                }

                @Override
                public String[] interactive(
                        ClientSession session, String name, String instruction,
                        String lang, String[] prompt, boolean[] echo) {
                    throw new UnsupportedOperationException("Unexpected interactive call");
                }

                @Override
                public String getUpdatedPassword(ClientSession session, String prompt, String lang) {
                    throw new UnsupportedOperationException("Unexpected password update call");
                }
            });
            CoreModuleProperties.WELCOME_BANNER.set(sshd, getCurrentTestName());

            try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
                    .verify(CONNECT_TIMEOUT)
                    .getSession()) {
                assertTrue(sigSem.tryAcquire(DEFAULT_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS),
                        "Welcome not signalled on time");
                session.addPasswordIdentity(getCurrentTestName());
                session.auth().verify(AUTH_TIMEOUT);
            }

        } finally {
            client.setUserInteraction(ui);
        }
    }

    /**
     * Resolves the banner file for the current test, located under this class's temporary target folder. Any file left
     * over from a previous run is deleted.
     *
     * @return the (now non-existent) banner file location
     * @throws Exception if the folder cannot be verified or the stale file cannot be removed
     */
    private Path prepareBannerFile() throws Exception {
        Path dir = getTempTargetRelativeFile(getClass().getSimpleName());
        Path file = assertHierarchyTargetFolderExists(dir).resolve(getCurrentTestName() + ".txt");
        Files.deleteIfExists(file);
        return file;
    }

    /**
     * Writes a known banner text to a file and configures the server banner with a value derived from that file's path.
     * It then checks that the client receives the file's content as the banner.
     *
     * @param configValueExtractor maps the banner file path to the object stored in
     *                             {@link CoreModuleProperties#WELCOME_BANNER}
     * @throws Exception if writing the file or running the session fails
     */
    private void testFileContentBanner(Function<? super Path, ?> configValueExtractor) throws Exception {
        Path file = prepareBannerFile();
        String expectedWelcome = getClass().getName() + "#" + getCurrentTestName();
        Files.write(file, expectedWelcome.getBytes(StandardCharsets.UTF_8));
        Object configValue = configValueExtractor.apply(file);
        CoreModuleProperties.WELCOME_BANNER.set(sshd, configValue);
        testBanner(expectedWelcome);
    }

    /**
     * Connects and password-authenticates a client session using a recording {@link UserInteraction}. It then asserts
     * which banner, if any, was received. All callbacks must be invoked for the same session. When no banner is
     * expected, none of the recording callbacks may be invoked at all.
     *
     * @param expectedWelcome the expected banner text, or {@code null} if no banner callback is expected
     * @throws Exception if connecting or authenticating fails, or an assertion fails
     */
    private void testBanner(String expectedWelcome) throws Exception {
        UserInteraction ui = client.getUserInteraction();
        AtomicReference<String> welcomeHolder = new AtomicReference<>();
        try {
            AtomicReference<ClientSession> sessionHolder = new AtomicReference<>();
            client.setUserInteraction(new UserInteraction() {
                @Override
                public boolean isInteractionAllowed(ClientSession session) {
                    return true;
                }

                @Override
                public void serverVersionInfo(ClientSession session, List<String> lines) {
                    validateSession("serverVersionInfo", session);
                }

                @Override
                public void welcome(ClientSession session, String banner, String lang) {
                    validateSession("welcome", session);
                    assertNull(welcomeHolder.getAndSet(banner), "Multiple banner invocations");
                }

                @Override
                public String[] interactive(
                        ClientSession session, String name, String instruction,
                        String lang, String[] prompt, boolean[] echo) {
                    validateSession("interactive", session);
                    return null;
                }

                @Override
                public String getUpdatedPassword(ClientSession clientSession, String prompt, String lang) {
                    throw new UnsupportedOperationException("Unexpected call");
                }

                // Records the first session seen and checks that every later callback uses the same one.
                private void validateSession(String phase, ClientSession session) {
                    ClientSession prev = sessionHolder.getAndSet(session);
                    if (prev != null) {
                        assertSame(prev, session, "Mismatched " + phase + " client session");
                    }
                }
            });

            try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
                    .verify(CONNECT_TIMEOUT).getSession()) {
                session.addPasswordIdentity(getCurrentTestName());
                session.auth().verify(AUTH_TIMEOUT);
                if (expectedWelcome != null) {
                    assertSame(session, sessionHolder.get(), "Mismatched sessions");
                } else {
                    assertNull(sessionHolder.get(), "Unexpected session");
                }
            }
        } finally {
            client.setUserInteraction(ui);
        }
        assertEquals(expectedWelcome, welcomeHolder.get(), "Mismatched banner");
    }
}
