"""
-Ergometer data analysis example-
Description: Contains only the basics for the analysis of ergometer data
For a more comprehensive package refer to: https://gitlab.com/Rickdkk/worklab
Author:     Rick de Klerk
Contact:    r.de.klerk@umcg.nl
Company:    University Medical Center Groningen
License:    GNU GPLv3.0
Date:       27/06/2019
"""
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.signal import butter, filtfilt

try:  # SciPy >= 1.6
    from scipy.integrate import cumulative_trapezoid
except ImportError:  # older SciPy only provides the legacy name
    from scipy.integrate import cumtrapz as cumulative_trapezoid
try:
    from scipy.integrate import cumtrapz
except ImportError:  # cumtrapz was removed in SciPy 1.14; keep the old name working
    cumtrapz = cumulative_trapezoid


def load_esseda(filename: str) -> dict:
    """Load HSB ergometer data from a LEM datafile.

    The "HSB" sheet is read in blocks of five columns: time, left user force, right user force,
    left speed and right speed. LEM continues the recording in new columns, so the blocks are
    stacked on top of each other to form one continuous recording.

    :param filename: full file path or file in existing path from LEM excel sheet
    :return: dictionary with a "left" and a "right" dataframe, each with "time", "uforce" and
             "speed" columns and a fresh 0..n-1 index (labels equal positions)
    :raises ValueError: if the sheet does not consist of whole blocks of five columns
    """
    df = pd.read_excel(filename, sheet_name="HSB")
    df = df.dropna(axis=1, how="all")  # remove empty columns

    def _to_numeric(col: pd.Series) -> pd.Series:
        # Text columns presumably hold numbers written with a decimal comma; convert them.
        if isinstance(col.iloc[0], str):
            return pd.to_numeric(col.str.replace(",", ".", regex=False))
        return col

    df = df.apply(_to_numeric, axis=0)

    n_cols = df.shape[1]
    if n_cols == 0 or n_cols % 5:
        # np.split would either fail cryptically or silently split into wrongly sized blocks
        raise ValueError(f"expected the HSB sheet to consist of blocks of 5 columns, found {n_cols} columns")
    blocks = np.split(df.to_numpy(), n_cols // 5, axis=1)
    stacked = np.concatenate(tuple(blocks), axis=0)

    data = {}
    for side, force_col, speed_col in (("left", 1, 3), ("right", 2, 4)):
        frame = pd.DataFrame({"time": stacked[:, 0],
                              "uforce": stacked[:, force_col],
                              "speed": stacked[:, speed_col]})
        # Drop NaN padding rows, then reset the index so labels match positions: downstream code
        # uses positional sample indices.
        data[side] = frame.dropna().reset_index(drop=True)
    return data


def filter_data(data, co_f=15, ord_f=2, force=True, co_s=6, ord_s=2, speed=True):
    """ Filters measurement wheel data; should be used before processing

    :param data: measurement wheel data
    :param co_f: cut off frequency for force related variables
    :param ord_f: filter order for force related variables
    :param force: filter force toggle
    :param co_s: cut off frequency for speed related variables
    :param ord_s: filter order for speed related variables
    :param speed: filter speed toggle
    :return: same data but filtered
    """
    sfreq = 100  # overrides default; ergometer is always 100Hz
    for side in data:
        if force:
            b, a = butter(ord_f, co_f / (0.5 * sfreq), 'low')
            if "force" in data[side]:  # not from LEM
                data[side]["force"] = filtfilt(b, a, data[side]["force"])
            else:  # from LEM
                data[side]["uforce"] = filtfilt(b, a, data[side]["uforce"])
        if speed:
            b, a = butter(ord_s, co_s / (0.5 * sfreq), 'low')
            data[side]["speed"] = filtfilt(b, a, data[side]["speed"])
    return data


def process_data(data: dict, wheelsize: float = 0.31, rimsize: float = 0.27, sfreq: float = 100) -> dict:
    """Do all basic calculations and conversions on measurement wheel data.

    The columns force, torque, acc, power, dist, work, uforce, aspeed and angle are added to (or
    overwritten in) each side's dataframe in place; the same dict is returned.

    :param data: filtered measurement wheel data; per side a dataframe with a "speed" column and
                 either "uforce" (LEM export) or "force"
    :param wheelsize: radius of wheelchair wheel
    :param rimsize: handrim radius
    :param sfreq: sampling frequency in Hz (the ergometer records at 100 Hz)
    :return: extended dataframe with regular outcome values
    """
    ratio = wheelsize / rimsize  # converts between force at the handrim and force at the wheel
    for side in data:
        frame = data[side]
        if "uforce" in frame:  # LEM export
            frame["force"] = frame["uforce"] / ratio
        frame["torque"] = frame["force"] * wheelsize
        frame["acc"] = np.gradient(frame["speed"]) * sfreq
        frame["power"] = frame["speed"] * frame["force"]
        frame["dist"] = cumulative_trapezoid(frame["speed"], initial=0.0) / sfreq
        frame["work"] = frame["power"] / sfreq  # work per sample
        frame["uforce"] = frame["force"] * ratio
        frame["aspeed"] = frame["speed"] / wheelsize
        frame["angle"] = cumulative_trapezoid(frame["aspeed"], initial=0.0) / sfreq
    return data


def find_peaks(data: pd.Series, cutoff: float = 1.0, minpeak: float = 5.0) -> pd.DataFrame:
    """Find positive peaks (pushes) in a signal and return their indices.

    A push is a run of samples that contains a sample above ``minpeak`` and is bounded on both
    sides by samples below ``cutoff`` (a hysteresis band). Pushes that are not bounded on both
    sides, i.e. still in progress at the start or end of the signal, are discarded.

    :param data: any signal (Series or array) that contains peaks above minpeak that dip below cutoff
    :param cutoff: where the peak gets cut off at the bottom, basically a hysteresis band
    :param minpeak: minimum peak height of wave
    :return: dataframe with one row per push holding the **positional** ``start`` (first sample
             not below cutoff), ``stop`` (last sample not below cutoff) and ``peak`` (maximum)
             indices"""
    signal = np.asarray(data.values if isinstance(data, pd.Series) else data)
    below = signal < cutoff
    low = np.flatnonzero(below)  # sorted positions of all samples below the cutoff
    # A sample that is itself below the cutoff can never be part of a push (only relevant when
    # minpeak < cutoff).
    candidates = np.flatnonzero((signal > minpeak) & ~below)

    starts, stops = [], []
    last_stop = -1
    for prom in candidates:
        if prom <= last_stop:
            continue  # this sample belongs to a push that was already handled
        pos = np.searchsorted(low, prom)  # low[pos - 1] < prom < low[pos]
        if pos == len(low):
            break  # signal never drops below cutoff again: trailing push is incomplete
        last_stop = low[pos] - 1
        if pos == 0:
            continue  # signal never was below cutoff before: leading push is incomplete
        starts.append(low[pos - 1] + 1)
        stops.append(last_stop)

    peaks = [start + int(np.argmax(signal[start:stop + 1])) for start, stop in zip(starts, stops)]
    return pd.DataFrame({"start": np.asarray(starts, dtype=int),
                         "stop": np.asarray(stops, dtype=int),
                         "peak": np.asarray(peaks, dtype=int)})


# noinspection PyTypeChecker
def push_by_push(data: dict, pushes: dict) -> dict:
    """Calculate push-by-push statistics such as push time and power per push.

    :param data: processed measurement wheel data; per side a dataframe with at least the
                 "time", "angle", "power", "torque", "work" and "uforce" columns
    :param pushes: per side, the start, stop and peak **positional** indices of all pushes
                   (as returned by find_peaks); stop is the last sample of a push
    :return: dictionary with one push-by-push dataframe per side, one row per push"""
    pbp = {"left": [], "right": []}
    for side in data:
        frame = data[side]
        # Work on positional arrays: push indices are positions, and label-based lookups would
        # pick the wrong samples (or fail) if the dataframe index has gaps.
        time = frame["time"].to_numpy()
        angle = frame["angle"].to_numpy()
        power = frame["power"].to_numpy()
        torque = frame["torque"].to_numpy()
        work = frame["work"].to_numpy()
        uforce = frame["uforce"].to_numpy()

        starts = np.asarray(pushes[side]["start"], dtype=int)
        stops = np.asarray(pushes[side]["stop"], dtype=int)
        peaks = np.asarray(pushes[side]["peak"], dtype=int)

        stats = {key: [] for key in ("tstart", "tstop", "tpeak", "cangle", "ptime", "pout",
                                     "maxpout", "maxtorque", "meantorque", "work", "fpeak",
                                     "fmean", "slope")}
        for start, stop, peak in zip(starts, stops, peaks):
            push = slice(start, stop + 1)  # inclusive of the last sample of the push
            stats["tstart"].append(time[start])
            stats["tstop"].append(time[stop])
            stats["tpeak"].append(time[peak])
            stats["cangle"].append(angle[stop] - angle[start])  # angle covered during the push
            stats["ptime"].append(time[stop] - time[start])
            stats["pout"].append(np.mean(power[push]))
            stats["maxpout"].append(np.max(power[push]))
            stats["maxtorque"].append(np.max(torque[push]))
            stats["meantorque"].append(np.mean(torque[push]))
            stats["work"].append(np.sum(work[push]))
            stats["fpeak"].append(np.max(uforce[push]))
            stats["fmean"].append(np.mean(uforce[push]))
            # Max torque over the time from push start to peak sample; inf (with a numpy
            # RuntimeWarning) when the peak is the first sample of the push.
            stats["slope"].append(stats["maxtorque"][-1] / (time[peak] - time[start]))

        # Cycle time runs from the start of a push to the start of the next one, so the last push
        # has none (NaN); reltime is the push time as a percentage of the cycle time.
        ctime = np.full(len(starts), np.nan)
        ctime[:-1] = np.diff(np.asarray(stats["tstart"], dtype=float))
        reltime = np.asarray(stats["ptime"], dtype=float) / ctime * 100

        pbp[side] = pd.DataFrame({"start": starts, "stop": stops, "peak": peaks, **stats,
                                  "ctime": ctime, "reltime": reltime})
    return pbp


def main():
    """Run the example analysis and plot the detected pushes.

    Loads "ergometer_example_data.xls", filters and processes it, detects pushes in the torque
    signal of both sides and shows the torque with the push start and stop samples marked.
    """
    test_data = load_esseda("ergometer_example_data.xls")
    test_data = filter_data(test_data)
    test_data = process_data(test_data)
    test_pushes = {side: find_peaks(test_data[side]["torque"], cutoff=0.1) for side in ("left", "right")}
    test_pushes = push_by_push(test_data, test_pushes)

    fig, axes = plt.subplots(2, 1, sharex="all", sharey="all")
    for side, ax in zip(("left", "right"), axes):
        torque = test_data[side]["torque"].to_numpy()
        ax.plot(test_data[side]["time"], torque)
        # Push indices are positions, so index the plain array rather than by dataframe label.
        for index_key, time_key in (("start", "tstart"), ("stop", "tstop")):
            idx = test_pushes[side][index_key].to_numpy(dtype=int)
            ax.plot(test_pushes[side][time_key], torque[idx], "C1o")

    plt.autoscale(tight=True)
    axes[1].set_xlabel("Time (s)")
    axes[0].set_ylabel("Torque (Nm)")
    axes[1].set_ylabel("Torque (Nm)")
    fig.align_ylabels()
    plt.show()


# Run the example analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
