/*
    Copyright (c) 2014 Martin Sustrik  All rights reserved.

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"),
    to deal in the Software without restriction, including without limitation
    the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom
    the Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included
    in all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
    THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
    IN THE SOFTWARE.
*/

#include "../nn.h"

#if defined NN_HAVE_WINDOWS

#include "../utils/err.h"

int nn_tcpmuxd (int port)
{
    /*  The POSIX implementation of the daemon is built on AF_UNIX sockets
        and SCM_RIGHTS descriptor passing; this build does not provide it,
        so the call always fails with EPROTONOSUPPORT. */
    (void) port;
    errno = EPROTONOSUPPORT;
    return -1;
}

#else

#include "../utils/thread.h"
#include "../utils/attr.h"
#include "../utils/err.h"
#include "../utils/int.h"
#include "../utils/cont.h"
#include "../utils/wire.h"
#include "../utils/alloc.h"
#include "../utils/list.h"
#include "../utils/mutex.h"
#include "../utils/closefd.h"

#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/time.h>
#include <sys/un.h>
#include <stddef.h>
#include <ctype.h>
#include <poll.h>

/*  Daemon state created by nn_tcpmuxd() and handed to the daemon thread
    (nn_tcpmuxd_routine) as its argument. */
struct nn_tcpmuxd_ctx {
    /*  Listening TCP socket on which TCPMUX clients connect. */
    int tcp_listener;
    /*  Listening AF_UNIX socket on which local processes register the
        service they want to receive connections for. */
    int ipc_listener;
    /*  Registered services: elements are struct nn_tcpmuxd_conn, linked
        through their 'item' member. */
    struct nn_list conns;
    /*  Thread running nn_tcpmuxd_routine. */
    struct nn_thread thread;
};

/*  One service registration received over the IPC listener. */
struct nn_tcpmuxd_conn {
    /*  Accepted IPC connection; incoming TCP connections for this service
        are passed to the registered process over it via send_fd(). */
    int fd;
    /*  Heap-allocated, NUL-terminated service name; compared with
        strcmp() against the service requested in the TCPMUX header. */
    char *service;
    /*  Link in nn_tcpmuxd_ctx.conns. */
    struct nn_list_item item;
};

/*  Forward declarations. nn_tcpmuxd() starts nn_tcpmuxd_routine on a
    dedicated thread, and the routine hands accepted TCP connections to
    registered processes with send_fd(); both are defined further below. */
static void nn_tcpmuxd_routine (void *arg);
static int send_fd (int s, int fd);

/*  Start the TCPMUX daemon.

    Listens for TCPMUX clients on TCP port 'port' (all interfaces) and for
    service registrations on the AF_UNIX socket /tmp/tcpmux-<port>.ipc, then
    runs nn_tcpmuxd_routine on a dedicated thread to serve both.

    Returns 0 on success. On failure returns -1 and sets errno: EINVAL if
    'port' is outside 0..65535, otherwise the errno of the failing socket
    call (e.g. EADDRINUSE when the port is taken). No socket is leaked on
    failure. */
int nn_tcpmuxd (int port)
{
    int rc;
    int err;
    int tcp_listener;
    int ipc_listener;
    int opt;
    struct sockaddr_in tcp_addr;
    struct sockaddr_un ipc_addr;
    struct nn_tcpmuxd_ctx *ctx;

    /*  TCP ports are 16-bit; anything else would be silently truncated
        by htons(). */
    if (port < 0 || port > 65535) {
        errno = EINVAL;
        return -1;
    }

    tcp_listener = -1;
    ipc_listener = -1;

    /*  Start listening on the specified TCP port. */
    tcp_listener = socket (AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (tcp_listener < 0)
        goto fail;
    opt = 1;
    rc = setsockopt (tcp_listener, SOL_SOCKET, SO_REUSEADDR, &opt,
        sizeof (opt));
    if (rc != 0)
        goto fail;
    memset (&tcp_addr, 0, sizeof (tcp_addr));
    tcp_addr.sin_family = AF_INET;
    tcp_addr.sin_port = htons ((uint16_t) port);
    tcp_addr.sin_addr.s_addr = htonl (INADDR_ANY);
    rc = bind (tcp_listener, (struct sockaddr*) &tcp_addr, sizeof (tcp_addr));
    if (rc != 0)
        goto fail;
    rc = listen (tcp_listener, 100);
    if (rc != 0)
        goto fail;

    /*  Start listening for incoming IPC connections. Zero the address first
        so that platform-specific fields (e.g. sun_len on BSDs) are
        well defined. */
    memset (&ipc_addr, 0, sizeof (ipc_addr));
    ipc_addr.sun_family = AF_UNIX;
    snprintf (ipc_addr.sun_path, sizeof (ipc_addr.sun_path),
        "/tmp/tcpmux-%d.ipc", (int) port);
    /*  Remove a socket file possibly left behind by a previous instance;
        bind() would fail with EADDRINUSE otherwise. */
    unlink (ipc_addr.sun_path);
    ipc_listener = socket (AF_UNIX, SOCK_STREAM, 0);
    if (ipc_listener < 0)
        goto fail;
    rc = bind (ipc_listener, (struct sockaddr*) &ipc_addr, sizeof (ipc_addr));
    if (rc != 0)
        goto fail;
    rc = listen (ipc_listener, 100);
    if (rc != 0)
        goto fail;

    /*  Allocate a context for the daemon. */
    ctx = nn_alloc (sizeof (struct nn_tcpmuxd_ctx), "tcpmuxd context");
    alloc_assert (ctx);
    ctx->tcp_listener = tcp_listener;
    ctx->ipc_listener = ipc_listener;
    nn_list_init (&ctx->conns);

    /*  Run the daemon in a dedicated thread. The context and both listeners
        are owned by that thread from here on. */
    nn_thread_init (&ctx->thread, nn_tcpmuxd_routine, ctx);

    return 0;

fail:
    /*  close() may clobber errno; report the error of the failing call. */
    err = errno;
    if (ipc_listener >= 0)
        close (ipc_listener);
    if (tcp_listener >= 0)
        close (tcp_listener);
    errno = err;
    return -1;
}

/*  Read a TCPMUX request line (service name terminated by CR LF) from
    'conn' into 'service' (capacity 'size'). On success the name is
    lowercased, NUL-terminated with the CR LF stripped, and 0 is returned.
    Returns -1 on receive error/timeout, on EOF, or if the line does not fit;
    the caller is expected to drop the connection in that case. */
static int nn_tcpmuxd_read_header (int conn, char *service, size_t size)
{
    size_t pos;
    ssize_t ssz;

    pos = 0;
    while (pos < size) {
        ssz = recv (conn, &service [pos], 1, 0);
        if (ssz < 0 && errno == EINTR)
            continue;
        if (ssz != 1)
            return -1;
        /*  <ctype.h> functions require a value representable as
            unsigned char. */
        service [pos] = (char) tolower ((unsigned char) service [pos]);
        if (pos > 0 && service [pos - 1] == 0x0d && service [pos] == 0x0a) {
            service [pos - 1] = 0;
            return 0;
        }
        ++pos;
    }

    /*  Header longer than the buffer: reject rather than abort. */
    return -1;
}

/*  Return the registration for 'service', or NULL if nobody registered it.
    If several processes registered the same name, the earliest one wins. */
static struct nn_tcpmuxd_conn *nn_tcpmuxd_find (struct nn_tcpmuxd_ctx *ctx,
    const char *service)
{
    struct nn_list_item *it;
    struct nn_tcpmuxd_conn *tc;

    for (it = nn_list_begin (&ctx->conns);
          it != nn_list_end (&ctx->conns);
          it = nn_list_next (&ctx->conns, it)) {
        tc = nn_cont (it, struct nn_tcpmuxd_conn, item);
        if (strcmp (service, tc->service) == 0)
            return tc;
    }
    return NULL;
}

/*  Unregister 'tc': unlink it from the list, close its IPC connection and
    release its memory. */
static void nn_tcpmuxd_drop (struct nn_tcpmuxd_ctx *ctx,
    struct nn_tcpmuxd_conn *tc)
{
    nn_list_erase (&ctx->conns, &tc->item);
    nn_list_item_term (&tc->item);
    close (tc->fd);
    nn_free (tc->service);
    nn_free (tc);
}

/*  Send a short TCPMUX reply ("+\r\n" or "-\r\n") to 'conn'.
    Returns 0 if the whole reply was sent, -1 otherwise. */
static int nn_tcpmuxd_reply (int conn, const char *reply)
{
    size_t len;
    ssize_t ssz;
    int flags;

    len = strlen (reply);
    flags = 0;
#if defined MSG_NOSIGNAL
    /*  A client that already went away must yield EPIPE, not a SIGPIPE
        that would kill the whole daemon. */
    flags = MSG_NOSIGNAL;
#endif
    ssz = send (conn, reply, len, flags);
    return (ssz >= 0 && (size_t) ssz == len) ? 0 : -1;
}

/*  Handle a pending TCP connection: read the TCPMUX header, answer it and,
    if the requested service is registered, pass the socket to that
    process. Misbehaving or vanished clients are disconnected, never fatal. */
static void nn_tcpmuxd_accept_tcp (struct nn_tcpmuxd_ctx *ctx)
{
    int rc;
    int conn;
    struct timeval tv;
    char service [256];
    struct nn_tcpmuxd_conn *tc;

    /*  Accept the connection. */
    conn = accept (ctx->tcp_listener, NULL, NULL);
    if (conn < 0 && (errno == ECONNABORTED || errno == EINTR))
        return;
    errno_assert (conn >= 0);

    /*  Bound how long a slow or silent client can stall the daemon, which
        serves all clients from this single thread. Without the timeouts
        the connection is not safe to serve, so drop it. */
    tv.tv_sec = 0;
    tv.tv_usec = 100000;
    rc = setsockopt (conn, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof (tv));
    if (rc == 0)
        rc = setsockopt (conn, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof (tv));
    if (rc != 0) {
        close (conn);
        return;
    }

    /*  Read TCPMUX header. */
    if (nn_tcpmuxd_read_header (conn, service, sizeof (service)) != 0) {
        close (conn);
        return;
    }

    /*  If no one is listening, tear down the connection. The negative
        reply is best-effort: the connection is closed either way. */
    tc = nn_tcpmuxd_find (ctx, service);
    if (!tc) {
        (void) nn_tcpmuxd_reply (conn, "-\x0d\x0a");
        close (conn);
        return;
    }

    /*  Send TCPMUX reply. */
    if (nn_tcpmuxd_reply (conn, "+\x0d\x0a") != 0) {
        close (conn);
        return;
    }

    /*  Pass the file descriptor to the listening process. send_fd() closes
        'conn' only on success. A failure most likely means the registered
        process is gone, so its stale registration is removed as well. */
    rc = send_fd (tc->fd, conn);
    if (rc != 0) {
        close (conn);
        nn_tcpmuxd_drop (ctx, tc);
    }
}

/*  Handle a pending IPC connection: read the service registration and add
    it to ctx->conns. The header is a 2-byte length (decoded with nn_gets)
    followed by that many bytes of service name. Malformed or truncated
    registrations are discarded. */
static void nn_tcpmuxd_accept_ipc (struct nn_tcpmuxd_ctx *ctx)
{
    int conn;
    unsigned char buf [2];
    size_t sz;
    size_t i;
    ssize_t ssz;
    char *service;
    struct nn_tcpmuxd_conn *tc;

    /*  Accept the connection. */
    conn = accept (ctx->ipc_listener, NULL, NULL);
    if (conn < 0 && (errno == ECONNABORTED || errno == EINTR))
        return;
    errno_assert (conn >= 0);

    /*  Read the connection header. MSG_WAITALL protects against short
        reads on the stream socket. NOTE(review): this blocks the daemon
        until the registering process sends its header. */
    ssz = recv (conn, buf, sizeof (buf), MSG_WAITALL);
    if (ssz != (ssize_t) sizeof (buf)) {
        close (conn);
        return;
    }
    sz = nn_gets (buf);
    service = nn_alloc (sz + 1, "tcpmuxd_conn.service");
    alloc_assert (service);
    ssz = recv (conn, service, sz, MSG_WAITALL);
    if (ssz < 0 || (size_t) ssz != sz) {
        nn_free (service);
        close (conn);
        return;
    }

    /*  Service names are matched case-insensitively: the TCP side lowercases
        the requested name, so lowercase the registered one too. */
    for (i = 0; i != sz; ++i)
        service [i] = (char) tolower ((unsigned char) service [i]);
    service [sz] = 0;

    /*  Create new connection entry and add it to the IPC connections list. */
    tc = nn_alloc (sizeof (struct nn_tcpmuxd_conn), "tcpmuxd_conn");
    alloc_assert (tc);
    tc->fd = conn;
    tc->service = service;
    nn_list_item_init (&tc->item);
    nn_list_insert (&ctx->conns, &tc->item, nn_list_end (&ctx->conns));
}

/*  Main body of the daemon: wait for activity on the TCP and IPC listeners
    and dispatch to the handlers above. Never returns. */
static void nn_tcpmuxd_routine (void *arg)
{
    int rc;
    struct nn_tcpmuxd_ctx *ctx;
    struct pollfd pfd [2];

    ctx = (struct nn_tcpmuxd_ctx*) arg;

    pfd [0].fd = ctx->tcp_listener;
    pfd [0].events = POLLIN;
    pfd [1].fd = ctx->ipc_listener;
    pfd [1].events = POLLIN;

    while (1) {

        /*  Wait for events. A signal interrupting poll() is not an error. */
        rc = poll (pfd, 2, -1);
        if (rc < 0 && errno == EINTR)
            continue;
        errno_assert (rc >= 0);
        nn_assert (rc != 0);

        /*  There's an incoming TCP connection. */
        if (pfd [0].revents & POLLIN)
            nn_tcpmuxd_accept_tcp (ctx);

        /*  There's an incoming IPC connection. */
        if (pfd [1].revents & POLLIN)
            nn_tcpmuxd_accept_ipc (ctx);
    }
}

/*  Send file descriptor 'fd' to the process on the other end of the IPC
    socket 's', as SCM_RIGHTS ancillary data attached to a one-byte dummy
    payload.

    On success the local copy of 'fd' is closed and 0 is returned. On
    failure -1 is returned with errno set by sendmsg(), and 'fd' is left
    open; the caller must close it. */
static int send_fd (int s, int fd)
{
    ssize_t rc;
    int flags;
    struct iovec iov;
    char c;
    struct msghdr msg;
    /*  The control buffer must be suitably aligned for struct cmsghdr, hence
        the union; CMSG_SPACE includes any padding the platform requires
        after the header and the payload. */
    union {
        struct cmsghdr align;
        char buf [CMSG_SPACE (sizeof (int))];
    } control;
    struct cmsghdr *cmsg;

    /*  Compose the message. We'll send one byte long dummy message
        accompanied with the fd. */
    c = 0;
    iov.iov_base = &c;
    iov.iov_len = 1;
    memset (&msg, 0, sizeof (msg));
    memset (&control, 0, sizeof (control));
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    msg.msg_control = control.buf;
    msg.msg_controllen = sizeof (control.buf);

    /*  Attach the file descriptor to the message. CMSG_DATA is not
        guaranteed to be int-aligned, so copy bytes rather than storing
        through a cast pointer. */
    cmsg = CMSG_FIRSTHDR (&msg);
    nn_assert (cmsg);
    cmsg->cmsg_level = SOL_SOCKET;
    cmsg->cmsg_type = SCM_RIGHTS;
    cmsg->cmsg_len = CMSG_LEN (sizeof (fd));
    memcpy (CMSG_DATA (cmsg), &fd, sizeof (fd));

    /*  If the registered process has exited, report EPIPE instead of
        letting SIGPIPE terminate the daemon. */
    flags = 0;
#if defined MSG_NOSIGNAL
    flags = MSG_NOSIGNAL;
#endif

    /*  Pass the file descriptor to the registered process. */
    do {
        rc = sendmsg (s, &msg, flags);
    } while (rc < 0 && errno == EINTR);
    if (rc < 0)
        return -1;
    nn_assert (rc == 1);

    /*  Sending the file descriptor to other process acts as dup().
        Therefore, we have to close the local copy of the file descriptor. */
    nn_closefd (fd);

    return 0;
}

#endif
