"""
-Measurement wheel analysis example-
Description: Contains only the basics for the analysis of measurement wheel data
For a more comprehensive package refer to: https://gitlab.com/Rickdkk/worklab
Author:     Rick de Klerk
Contact:    r.de.klerk@umcg.nl
Company:    University Medical Center Groningen
License:    GNU GPLv3.0
Date:       27/06/2019
"""
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.signal import butter, filtfilt

try:  # SciPy >= 1.6 spelling; the legacy ``cumtrapz`` name was removed in SciPy 1.14
    from scipy.integrate import cumulative_trapezoid
except ImportError:  # older SciPy only provides the legacy name
    from scipy.integrate import cumtrapz as cumulative_trapezoid
cumtrapz = cumulative_trapezoid  # keep the previous module-level name available


def load_opti(filename: str) -> pd.DataFrame:
    """Read an Optipush ``.data`` export into a dataframe.

    The first 12 lines of the export are skipped as header; the remaining
    tab-separated rows are parsed as float64. File column 0 becomes ``time``
    and file columns 3-9 become the kinetic columns.

    :param filename: filename or path to Optipush file
    :return: dataframe with time, fx, fy, fz, mx, my, torque and angle columns;
        ``angle`` is converted from degrees to radians and the sign of
        ``torque`` is flipped"""
    columns = ("time", "fx", "fy", "fz", "mx", "my", "torque", "angle")
    frame = pd.read_csv(
        filename,
        names=list(columns),
        delimiter="\t",
        usecols=[0, *range(3, 10)],
        dtype=dict.fromkeys(columns, np.float64),
        skiprows=12,
    )
    frame["angle"] = frame["angle"] * (np.pi / 180)  # degrees -> radians
    frame["torque"] = -frame["torque"]  # flip sign convention of the export
    return frame


def load_sw(filename: str, sfreq: int = 200) -> pd.DataFrame:
    """Loads SMARTwheel data from .csv file

    The file is read without a header row. pandas ignores the *order* of
    ``usecols`` and assigns ``names`` to the selected columns in file order,
    so passing ``names`` together with the unordered ``usecols`` below would
    mislabel columns (file column 3 would become ``fx``). Columns are
    therefore read by position and renamed explicitly.

    :param filename: filename or path to SMARTwheel data
    :param sfreq: sample frequency in Hz; file column 1 is divided by it to
        obtain ``time``, default is 200
    :return: dataframe with time, fx, fy, fz, mx, my, torque and angle columns;
        ``angle`` is converted from degrees to radians, unwrapped and sign-flipped"""

    names = ["time", "fx", "fy", "fz", "mx", "my", "torque", "angle"]
    usecols = [1, 18, 19, 20, 21, 22, 23, 3]  # file column for each entry of ``names``
    # With header=None the columns are labelled by their file position.
    raw = pd.read_csv(filename, header=None, usecols=usecols, dtype=np.float64)
    sw_df = pd.DataFrame({name: raw[col] for name, col in zip(names, usecols)})
    sw_df["time"] /= sfreq
    sw_df["angle"] = np.unwrap(sw_df["angle"] * (np.pi / 180)) * -1  # in radians
    return sw_df


def filter_data(data, sfreq=200, co_f=15, ord_f=2, force=True, co_s=6, ord_s=2, speed=True):
    """Low-pass filters measurement wheel data; should be used before processing

    Zero-phase Butterworth filters (``scipy.signal.filtfilt``) are applied to
    the force-related columns (fx, fy, fz, mx, my, torque) and, separately,
    to the ``angle`` column. Columns are overwritten in place.

    :param data: measurement wheel data; modified in place
    :param sfreq: specific sample freq for measurement wheel
    :param co_f: cut off frequency for force related variables
    :param ord_f: filter order for force related variables
    :param force: filter force toggle
    :param co_s: cut off frequency for speed related variables
    :param ord_s: filter order for speed related variables
    :param speed: filter speed toggle
    :return: the same ``data`` object, filtered"""

    nyquist = 0.5 * sfreq
    if force:
        # The coefficients are identical for every force-related column,
        # so the filter is designed once instead of once per column.
        b, a = butter(ord_f, co_f / nyquist, 'low')
        for var in ("fx", "fy", "fz", "mx", "my", "torque"):
            data[var] = filtfilt(b, a, data[var])
    if speed:
        b, a = butter(ord_s, co_s / nyquist, 'low')
        data["angle"] = filtfilt(b, a, data["angle"])
    return data


def process_data(data, wheelsize=0.31, rimsize=0.27, sfreq=200):
    """Does all basic calculations and conversions on measurement wheel data

    Derived columns are added to ``data`` in place. Derivatives use central
    differences (``np.gradient``) and distance uses cumulative trapezoidal
    integration, both with a sample spacing of ``1 / sfreq``.

    :param data: filtered measurement wheel data with ``angle`` (radians, as
        produced by the loaders in this file), ``fx``, ``fy``, ``fz`` and
        ``torque`` columns
    :param wheelsize: radius of wheelchair wheel
    :param rimsize: handrim radius
    :param sfreq: specific sample frequency of measurement wheel
    :return: the same dataframe extended with aspeed (angular velocity),
        speed, dist, acc, ftot (resultant of fx/fy/fz), uforce (torque over
        handrim radius), force, power and work (per-sample work)"""

    dt = 1 / sfreq  # sample interval
    data["aspeed"] = np.gradient(data["angle"], dt)
    data["speed"] = data["aspeed"] * wheelsize
    # ``cumulative_trapezoid`` replaces the ``cumtrapz`` alias removed in SciPy 1.14.
    data["dist"] = cumulative_trapezoid(data["speed"], dx=dt, initial=0)
    data["acc"] = np.gradient(data["speed"], dt)
    data["ftot"] = (data["fx"] ** 2 + data["fy"] ** 2 + data["fz"] ** 2) ** 0.5
    data["uforce"] = data["torque"] / rimsize
    # Handrim force rescaled to the wheel radius (algebraically torque / wheelsize).
    data["force"] = data["uforce"] / (wheelsize / rimsize)
    data["power"] = data["torque"] * data["aspeed"]
    data["work"] = data["power"] * dt  # work done during each sample interval
    return data


def find_peaks(data: pd.Series, cutoff: float = 1.0, minpeak: float = 5.0) -> pd.DataFrame:
    """Finds positive peaks in signal and returns their sample indices

    A peak is a contiguous run of samples that rises above ``minpeak`` and is
    bounded on both sides by a sample below ``cutoff`` (a hysteresis band).
    Runs that are not closed on both sides inside the recording (e.g. a push
    still in progress at the first or last sample) are ignored.

    :param data: any signal that contains peaks above minpeak that dip below cutoff;
        a pandas Series or array-like, indices returned are positional
    :param cutoff: where the peak gets cut off at the bottom, basically a hysteresis band
    :param minpeak: minimum peak height of wave
    :return: dataframe with one row per peak and integer columns ``start`` (first
        sample after the preceding dip below cutoff), ``stop`` (last sample before
        the next dip below cutoff) and ``peak`` (index of the maximum within
        start..stop, inclusive); empty when no peaks are found"""

    signal = np.asarray(data.values if isinstance(data, pd.Series) else data)
    below = np.flatnonzero(signal < cutoff)  # sorted indices of samples under the band
    starts, stops = [], []
    last_stop = -1
    for prom in np.flatnonzero(signal > minpeak):
        if prom <= last_stop:
            continue  # sample belongs to a push that was already found
        # below[pos] is the first sub-cutoff sample at or after prom,
        # below[pos - 1] the last one before prom.
        pos = np.searchsorted(below, prom)
        if pos == 0 or pos == len(below) or below[pos] == prom:
            # No dip before/after prom inside the recording, or the candidate
            # itself lies below the band (only possible when cutoff > minpeak).
            continue
        starts.append(below[pos - 1] + 1)
        stops.append(below[pos] - 1)
        last_stop = stops[-1]
    peaks = [start + np.argmax(signal[start:stop + 1]) for start, stop in zip(starts, stops)]
    return pd.DataFrame({
        "start": np.asarray(starts, dtype=np.intp),
        "stop": np.asarray(stops, dtype=np.intp),
        "peak": np.asarray(peaks, dtype=np.intp),
    })


# noinspection PyTypeChecker
def push_by_push(data: pd.DataFrame, pushes: pd.DataFrame) -> pd.DataFrame:
    """Calculates push-by-push statistics such as push time and power per push.

    Push indices are treated as *positional* sample indices into ``data`` (as
    returned by ``find_peaks``), so ``data`` does not need a default index and
    ``pushes`` may be a filtered / re-indexed subset.

    :param data: processed measurement wheel data (needs time, angle, power,
        torque, work, uforce and ftot columns)
    :param pushes: start, stop, peak indices of all pushes
    :return: push-by-push dataframe with one row per push (index taken from
        ``pushes``); ``ctime`` is the time from this push's start to the next
        push's start and ``reltime`` is ptime as a percentage of ctime, both
        NaN for the last push; empty (with all columns) if there are no pushes"""

    stats = ("tstart", "tstop", "tpeak", "cangle", "ptime", "pout", "maxpout", "maxtorque",
             "meantorque", "work", "fpeak", "fmean", "feff", "slope")
    pbp = {"start": pushes["start"], "stop": pushes["stop"], "peak": pushes["peak"]}
    pbp.update({key: [] for key in stats})  # pre-create so an empty result keeps all columns

    time, angle = data["time"], data["angle"]
    for start, stop, peak in zip(pushes["start"], pushes["stop"], pushes["peak"]):  # for each push
        start, stop, peak = int(start), int(stop), int(peak)
        push = slice(start, stop + 1)  # inclusive of last sample
        power = data["power"].iloc[push]
        torque = data["torque"].iloc[push]
        uforce = data["uforce"].iloc[push]

        pbp["tstart"].append(time.iloc[start])
        pbp["tstop"].append(time.iloc[stop])
        pbp["tpeak"].append(time.iloc[peak])
        pbp["cangle"].append(angle.iloc[stop] - angle.iloc[start])
        pbp["ptime"].append(pbp["tstop"][-1] - pbp["tstart"][-1])
        pbp["pout"].append(power.mean())
        pbp["maxpout"].append(power.max())
        pbp["maxtorque"].append(torque.max())
        pbp["meantorque"].append(torque.mean())
        pbp["work"].append(data["work"].iloc[push].sum())
        pbp["fpeak"].append(uforce.max())
        pbp["fmean"].append(uforce.mean())
        pbp["feff"].append((uforce / data["ftot"].iloc[push]).mean() * 100)
        pbp["slope"].append(pbp["maxtorque"][-1] / (pbp["tpeak"][-1] - pbp["tstart"][-1]))

    # Cycle time of push i spans from its start to the start of push i + 1.
    pbp["ctime"] = [t_next - t_cur for t_cur, t_next in zip(pbp["tstart"], pbp["tstart"][1:])]
    pbp["reltime"] = [ptime / ctime * 100 for ptime, ctime in zip(pbp["ptime"], pbp["ctime"])]
    if pbp["tstart"]:  # the last push has no following push, hence no cycle time
        pbp["ctime"].append(np.nan)
        pbp["reltime"].append(np.nan)
    return pd.DataFrame(pbp)


def main():
    """Run the example pipeline on the bundled Optipush recording.

    Loads, filters and processes the example file, detects pushes on the
    torque signal (cutoff 0.1) and plots torque over time with the start and
    stop sample of every push marked."""
    recording = load_opti("measurement_wheel_example_data.data")
    recording = process_data(filter_data(recording))
    pushes = push_by_push(recording, find_peaks(recording["torque"], cutoff=0.1))

    torque = recording["torque"]
    plt.plot(recording["time"], torque)
    for time_col, index_col in (("tstart", "start"), ("tstop", "stop")):
        plt.plot(pushes[time_col], torque[pushes[index_col]], "C1o")
    plt.autoscale(tight=True)
    plt.xlabel("Time (s)")
    plt.ylabel("Torque (Nm)")
    plt.show()


if __name__ == "__main__":
    # Run the example only when executed as a script, not when imported.
    main()
