# coding: utf-8

import sys

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as hc
import scipy.spatial as sp
import sklearn.model_selection
import torch

# Select matplotlib's "ggplot" style sheet at import time; this runs before
# any figure is created further down in this module.
plt.style.use("ggplot")
from datetime import datetime

import dataset
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from torch import nn, optim
from torch.autograd import Variable

# When True, `to_var` moves tensors to the GPU and `main` moves the model
# there and enables worker/pinned-memory options on the data loaders.
CUDA = False
# Mini-batch size handed to both DataLoaders in `main`.
BATCH_SIZE = 64
# Number of train/test passes executed by the training branch of `main`.
EPOCHS = 1000
# `train` prints a progress line every LOG_INTERVAL batches.
LOG_INTERVAL = 100

# Named matplotlib colors, indexed by factorized category code in
# `assign_colors` and `heatmap`; only five categories can be colored.
# https://matplotlib.org/examples/color/named_colors.html
COLORS = ["darkred", "darkgreen", "darkblue", "mediumvioletred", "darkgoldenrod"]


def to_var(x):
    """Wrap ``x`` in a ``Variable`` and place it on the GPU if ``CUDA`` is set.

    Returns the wrapped tensor; when the module-level ``CUDA`` flag is falsy
    the tensor stays on whatever device it was already on.
    """
    wrapped = Variable(x)
    return wrapped.cuda() if CUDA else wrapped


def one_hot(labels, class_size):
    """One-hot encode ``labels`` into a ``(labels.size(0), class_size)`` matrix.

    Each row is all zeros except for a 1 at the column given by the matching
    entry of ``labels``. The result is passed through ``to_var`` before being
    returned.
    """
    encoded = torch.zeros(labels.size(0), class_size)
    # Iterating a tensor yields row views, so writing to a row updates `encoded`.
    for row, label in zip(encoded, labels):
        row[label] = 1
    return to_var(encoded)


class CVAE(nn.Module):
    """Conditional variational autoencoder made of fully connected layers.

    The encoder maps a feature vector concatenated with a class vector to the
    mean and log-variance of the latent distribution. The decoder maps a
    latent vector concatenated with a class vector back to feature space.
    """

    def __init__(self, feature_size, hidden_size, latent_size, class_size):
        super().__init__()
        self.feature_size = feature_size
        self.class_size = class_size

        # Encoder: [x, c] -> hidden -> (mu, logvar)
        self.fc1 = nn.Linear(feature_size + class_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)  # mu head
        self.fc22 = nn.Linear(hidden_size, latent_size)  # logvar head

        # Decoder: [z, c] -> hidden -> reconstructed features
        self.fc3 = nn.Linear(latent_size + class_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, feature_size)

        self.relu = nn.ReLU()

    def encode(self, x, c):  # Q(z|x, c)
        """Return ``(mu, logvar)`` of the latent code for inputs ``x`` given ``c``.

        x: (bs, feature_size)
        c: (bs, class_size)
        """
        conditioned = torch.cat((x, c), dim=1)  # (bs, feature_size+class_size)
        hidden = self.relu(self.fc1(conditioned))
        return self.fc21(hidden), self.fc22(hidden)

    def reparametrize(self, mu, logvar):
        """Sample ``mu + eps * std`` in training mode; return ``mu`` otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise * std + mu

    def decode(self, z, c):  # P(x|z, c)
        """Map latent codes ``z`` conditioned on ``c`` back to feature space.

        z: (bs, latent_size)
        c: (bs, class_size)
        """
        conditioned = torch.cat((z, c), dim=1)  # (bs, latent_size+class_size)
        # NOTE(review): no activation is applied between fc3 and fc4, unlike
        # the encoder path -- confirm this is intended before changing it,
        # since saved checkpoints were produced with this behavior.
        hidden = self.fc3(conditioned)
        return self.fc4(hidden)

    def forward(self, x, c):
        """Encode, sample and decode; return ``(reconstruction, mu, logvar)``."""
        flat_x = x.view(-1, self.feature_size)
        mu, logvar = self.encode(flat_x, c)
        latent = self.reparametrize(mu, logvar)
        return self.decode(latent, c), mu, logvar


def loss_function(recon_x, x, mu, logvar) -> "tuple[torch.Tensor, torch.Tensor]":
    """Compute the two CVAE loss terms and return them separately.

    Args:
        recon_x: Reconstruction produced by the decoder.
        x: Original input; it is viewed as ``(-1, x.shape[1])`` before the
            comparison, so it must be at least 2-D.
        mu: Latent mean from the encoder.
        logvar: Latent log-variance from the encoder.

    Returns:
        ``(MSE, KLD)`` as two scalar tensors. ``MSE`` uses ``mse_loss`` with
        its default reduction (mean over all elements) while ``KLD`` is summed
        over all elements, so the two terms are on different scales; callers
        add them without weighting.
    """
    target = x.view(-1, x.shape[1])
    mse = nn.functional.mse_loss(recon_x, target)
    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return mse, kld


def plot(MSE, KLD, fname):
    """Save a two-panel figure with the MSE and KLD curves to ``fname``.

    Args:
        MSE: Sequence of MSE values, plotted against its index in the left
            panel.
        KLD: Sequence of KLD values, plotted against its index in the right
            panel.
        fname: Destination passed to ``Figure.savefig``.

    The figure is closed after saving so that repeated calls do not
    accumulate open figures in pyplot's global state.
    """
    fig = plt.figure(constrained_layout=True)
    try:
        gs = gridspec.GridSpec(1, 2, figure=fig)
        for col, (values, ylabel) in enumerate(((MSE, "MSE"), (KLD, "KLD"))):
            ax = fig.add_subplot(gs[0, col])
            ax.plot(np.arange(len(values)), values)
            ax.set_xlabel("itr")
            ax.set_ylabel(ylabel)

        fig.savefig(fname)
    finally:
        plt.close(fig)


def train(model, train_loader, optimizer, class_size, epoch):
    """Run one optimisation pass over ``train_loader``.

    For every batch the labels are one-hot encoded with ``class_size``
    columns, the model is evaluated, and the sum of the two terms returned by
    ``loss_function`` is back-propagated. A progress line is printed every
    ``LOG_INTERVAL`` batches and a summary line at the end; ``epoch`` is only
    used in those messages.

    Returns:
        ``(MSEs, KLDs)``: two lists holding the per-batch MSE and KLD values.
    """
    model.train()
    running_loss = 0
    mse_history, kld_history = [], []

    for batch_idx, (data, labels) in enumerate(train_loader):
        data = to_var(data)
        conditions = one_hot(labels, class_size)
        recon_batch, mu, logvar = model(data, conditions)

        optimizer.zero_grad()
        mse, kld = loss_function(recon_batch, data, mu, logvar)
        mse_history.append(mse.item())
        kld_history.append(kld.item())

        total = mse + kld
        total.backward()
        batch_loss = total.item()
        running_loss += batch_loss
        optimizer.step()

        if batch_idx % LOG_INTERVAL != 0:
            continue

        batch_len = len(data)
        message = "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
            epoch,
            batch_idx * batch_len,
            len(train_loader.dataset),
            100.0 * batch_idx / len(train_loader),
            batch_loss / batch_len,
        )
        print(message)

    average = running_loss / len(train_loader.dataset)
    print("====> Epoch: {} Average loss: {:.4f}".format(epoch, average))

    return mse_history, kld_history


def test(model, test_loader, class_size, epoch):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Labels are one-hot encoded with ``class_size`` columns. The summed
    per-batch losses divided by the dataset length are printed. ``epoch`` is
    accepted but not used by this function.

    Returns:
        ``(MSEs, KLDs)``: two lists holding the per-batch MSE and KLD values.
    """
    model.eval()
    mse_history, kld_history = [], []
    accumulated = 0

    with torch.no_grad():
        for data, labels in test_loader:
            data = to_var(data)
            conditions = one_hot(labels, class_size)
            recon_batch, mu, logvar = model(data, conditions)
            mse, kld = loss_function(recon_batch, data, mu, logvar)

            mse_value, kld_value = mse.item(), kld.item()
            mse_history.append(mse_value)
            kld_history.append(kld_value)
            accumulated += mse_value + kld_value

    accumulated /= len(test_loader.dataset)
    print("====> Test set loss: {:.4f}".format(accumulated))

    return mse_history, kld_history


def assign_colors(categorical):
    """Map each entry of ``categorical`` to a color name from ``COLORS``.

    The values are factorized with ``pd.factorize`` and each integer code is
    used as an index into ``COLORS``.

    Returns:
        ``(color_map, labels)``: the per-entry list of color names and the
        unique values returned by ``pd.factorize``.
    """
    codes, labels = pd.factorize(categorical)
    return [COLORS[code] for code in codes], labels


def heatmap(X, y, row_linkage, col_linkage):
    """Show a seaborn clustermap of ``X`` with a class color bar and legend.

    Args:
        X: Data matrix passed to ``sns.clustermap`` (called with
            ``z_score=0``).
        y: Categorical values used to color the columns via
            ``assign_colors`` -- presumably one entry per column of ``X``;
            verify against the caller.
        row_linkage: Precomputed linkage for the rows.
        col_linkage: Precomputed linkage for the columns.

    The function ends with ``plt.show()``.
    """
    colors, labels = assign_colors(y)
    g = sns.clustermap(
        X,
        figsize=(16, 16),
        z_score=0,
        row_linkage=row_linkage,
        col_linkage=col_linkage,
        col_colors=colors,
    )

    legend_ax = g.ax_col_dendrogram
    # Zero-sized bars draw nothing; they only register one legend handle per
    # class with the matching color.
    for i, label in enumerate(labels):
        legend_ax.bar(0, 0, color=COLORS[i], label=label, linewidth=0)
    # Build the legend once, after all handles exist, instead of rebuilding it
    # on every loop iteration. Skipped when there are no labels, as before.
    if len(labels):
        legend_ax.legend(loc="best", ncol=len(labels))

    plt.show()


def _plot_clustered_heatmap(features, labels):
    """Standardize ``features``, cluster both axes and show the heatmap."""
    std_vals = StandardScaler().fit_transform(features)
    sample_dist = sp.distance.pdist(std_vals, "euclidean")
    feature_dist = sp.distance.pdist(std_vals.T, "euclidean")

    sample_linkage = hc.linkage(sample_dist, method="average")
    feature_linkage = hc.linkage(feature_dist, method="average")

    # The matrix is transposed for display, so the feature linkage orders the
    # rows and the sample linkage orders the columns.
    heatmap(std_vals.T, labels, feature_linkage, sample_linkage)


def _run_training(X, y, classes, hidden_size, latent_size):
    """Train a CVAE on an 80/20 split, then save loss plots and weights."""
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X, y, test_size=0.2
    )
    train_data = dataset.Data(X_train, y_train, classes)
    test_data = dataset.Data(X_test, y_test, classes)

    loader_kwargs = {"num_workers": 1, "pin_memory": True} if CUDA else {}

    train_loader = torch.utils.data.DataLoader(
        dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, **loader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_data, batch_size=BATCH_SIZE, shuffle=True, **loader_kwargs
    )

    model = CVAE(X.shape[1], hidden_size, latent_size, len(classes))
    if CUDA:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    training_MSEs = []
    testing_MSEs = []
    training_KLDs = []
    testing_KLDs = []
    for epoch in range(1, EPOCHS + 1):
        MSEs, KLDs = train(model, train_loader, optimizer, len(classes), epoch)
        training_MSEs.append(MSEs)
        training_KLDs.append(KLDs)
        MSEs, KLDs = test(model, test_loader, len(classes), epoch)
        testing_MSEs.append(MSEs)
        testing_KLDs.append(KLDs)

    plot(
        np.array(training_MSEs).flatten(),
        np.array(training_KLDs).flatten(),
        "training_loss.png",
    )
    plot(
        np.array(testing_MSEs).flatten(),
        np.array(testing_KLDs).flatten(),
        "testing_loss.png",
    )
    torch.save(
        model.state_dict(),
        "model_{}.pt".format(datetime.now().strftime("%m-%d-%Y_%H:%M:%S")),
    )


def _generate_samples(
    model_path, feature_size, classes, hidden_size, latent_size, number_of_samples=500
):
    """Decode random latent vectors for every class and return them as a frame."""
    model = CVAE(feature_size, hidden_size, latent_size, len(classes))
    # Load onto the CPU first so a checkpoint saved from a GPU model can also
    # be read on a machine without one.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    if CUDA:
        # `to_var` puts the inputs on the GPU, so the model must be there too.
        model.cuda()
    model.eval()

    # One one-hot row per class: each decode call yields one sample per class.
    c = to_var(torch.eye(len(classes), len(classes)))

    # Seeding with an empty float64 array keeps the concatenated result
    # float64, and keeps the call valid when no batch is generated.
    batches = [np.empty((0, feature_size))]
    with torch.no_grad():
        for _ in range(number_of_samples):
            z = to_var(torch.randn(len(classes), latent_size))
            batches.append(model.decode(z, c).cpu().numpy())

    samples = pd.DataFrame(np.concatenate(batches))
    samples["classes"] = np.tile(classes, number_of_samples)
    return samples


def main(args, mode="generate"):
    """Entry point: load the CSV and run the selected mode.

    Args:
        args: Command-line arguments. ``args[0]`` is the path of the input
            CSV (first column used as index, the last two columns excluded
            from the features, the second-to-last column used as class label
            with missing values replaced by ``"normal"``). ``args[1]`` is the
            path of a saved model state dict and is required only when
            ``mode == "generate"``.
        mode: ``"training"`` trains a CVAE and saves loss plots and weights;
            ``"generate"`` (default) loads a trained model, decodes random
            samples for every class, shows their clustered heatmap and writes
            ``generated_samples.csv``; any other value shows a clustered
            heatmap of the input data itself.

    Raises:
        SystemExit: If a required command-line argument is missing.
    """
    if not args:
        raise SystemExit("usage: <script> DATA_CSV [MODEL_STATE_DICT]")
    if mode == "generate" and len(args) < 2:
        raise SystemExit("generate mode requires a model path: DATA_CSV MODEL_STATE_DICT")

    df = pd.read_csv(args[0], index_col=0)
    X = df.iloc[:, :-2].values
    X = X.astype(np.float32)
    Y = df.iloc[:, -2].fillna("normal").values
    y, classes = pd.factorize(Y)

    hidden_size = 400
    latent_size = 50

    if mode == "training":
        _run_training(X, y, classes, hidden_size, latent_size)
    elif mode == "generate":
        samples = _generate_samples(
            args[1], X.shape[1], classes, hidden_size, latent_size
        )
        # The last column holds the class labels, not a feature.
        _plot_clustered_heatmap(samples.iloc[:, :-1], samples["classes"].values)
        print("Saving generated samples...")
        samples.to_csv("generated_samples.csv")
    else:
        _plot_clustered_heatmap(X, Y)


# Script entry point: forward the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
