"""
-IMU analysis example-
Description: Contains only the basics for the analysis of IMU data
For a more comprehensive package refer to: https://gitlab.com/Rickdkk/worklab
Author:     Rick de Klerk
Contact:    r.de.klerk@umcg.nl
Company:    University Medical Center Groningen
License:    GNU GPLv3.0
Date:       27/06/2019
"""
import re
from glob import glob
from os import listdir, path
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from scipy.signal import filtfilt, butter

try:  # SciPy >= 1.6 provides cumulative_trapezoid; the old cumtrapz name was removed in SciPy 1.14
    from scipy.integrate import cumulative_trapezoid as cumtrapz
except ImportError:  # older SciPy only ships the original name
    from scipy.integrate import cumtrapz


def load_session(root_dir: str, filenames: list = None) -> dict:
    """Imports NGIMU session in nested dictionary with all devices and sensors. Translated from xio-Technologies.
    https://github.com/xioTechnologies/NGIMU-MATLAB-Import-Logged-Data-Example

    Every sub-directory of ``root_dir`` is treated as one device. The device key is the fourth space-separated word of
    the folder name (e.g. "NGIMU - 0035EDEE Frame" -> "Frame"), or the whole folder name if it has fewer words.
    Names containing "Links"/"Left", "Rechts"/"Right" or "Frame" are normalised to "Left", "Right" and "Frame".
    Every ``*.csv`` file in a device folder becomes one DataFrame, keyed by its file name without extension.
    Duplicate rows are dropped. Bracketed units and spaces are removed from column names
    ("Gyroscope X (deg/s)" -> "GyroscopeX").

    :param root_dir: directory where session is located
    :param filenames: Optional - list of sensor names or single sensor name that you would like to include
    :return: returns nested object sensordata[device][sensor][dataframe]
    :raises Exception: if root_dir is empty, contains no Session.xml, or a device folder yields no data
    """
    directory_contents = listdir(root_dir)  # all content in directory
    if not directory_contents:
        raise Exception("No contents in directory")
    if "Session.xml" not in directory_contents:
        raise Exception("Session.xml not found.")
    if isinstance(filenames, str):
        # A bare string would turn the membership test below into a substring match ("sens" in "sensors").
        filenames = [filenames]

    # Raw string: "\[" is an invalid escape sequence in a normal literal (SyntaxWarning on Python 3.12+).
    unit_pattern = re.compile(r"[(\[].*?[)\]]")
    directories = glob(f"{root_dir}/*/")  # folders of all devices
    session_data = dict()

    for sensordir in directories:  # loop through all sensor directories
        sensor_files = glob(f"{sensordir}/*.csv")
        # glob returns directories with a trailing separator, hence the double split to get the folder name
        folder_name = path.split(path.split(sensordir)[0])[-1]
        name_parts = folder_name.split(" ")
        device_name = name_parts[3] if len(name_parts) > 3 else folder_name
        device_name = "Left" if "Links" in device_name or "Left" in device_name else device_name
        device_name = "Right" if "Rechts" in device_name or "Right" in device_name else device_name
        device_name = "Frame" if "Frame" in device_name else device_name
        session_data[device_name] = dict()

        for sensor_file in sensor_files:  # loop through all csv files
            sensor_name = path.split(sensor_file)[-1].split(".csv")[0]  # sensor without path or extension

            if filenames and sensor_name not in filenames:
                continue  # skip if filenames is given and sensor not in filenames

            sensor_df = pd.read_csv(sensor_file).drop_duplicates()
            # remove units such as "(deg/s)" or "[g]" and all spaces from the column names
            sensor_df.rename(columns=lambda col: unit_pattern.sub("", col).replace(" ", ""), inplace=True)
            session_data[device_name][sensor_name] = sensor_df

        if not session_data[device_name]:
            raise Exception("No data was imported")
    return session_data


def resample_imu(sessiondata, samplefreq=400):
    """Resample all devices and sensors to a new, common sample frequency. Translated from xio-Technologies.

    Every dataframe is interpolated (via pd_interp) onto one shared time vector. The vector runs from 0 up to, but
    not including, the largest "Time" value found in any dataframe, in steps of 1/samplefreq. Quaternions are
    re-normalised to unit length after interpolation. Rotation-matrix dataframes ("matrix") are removed with a
    warning. The dictionary is modified in place and also returned.

    :param sessiondata: original sessiondata structure
    :param samplefreq: new intended sample frequency
    :return: resampled sessiondata
    """
    end_time = 0
    for sensors in sessiondata.values():
        for df in sensors.values():
            end_time = max(end_time, df["Time"].max())

    new_time = np.arange(0, end_time, 1/samplefreq)

    for sensors in sessiondata.values():
        # Iterate over a snapshot of the keys: "matrix" entries are removed inside the loop, and changing a dict's
        # size while iterating over it raises RuntimeError.
        for sensor in list(sensors):
            if sensor == "matrix":
                sensors.pop(sensor)
                warn("Rotation matrix cannot be resampled. This dataframe has been removed")
                continue

            resampled = pd_interp(sensors[sensor], "Time", new_time)
            if sensor == "quaternion":  # TODO: xio-tech has TODO here to replace this part with slerp
                # Linear interpolation does not preserve unit length, so normalise each quaternion (row-wise) over
                # its component columns only; the Time column must not be scaled.
                components = [col for col in resampled.columns if col != "Time"]
                if components:
                    values = resampled[components].to_numpy()
                    norms = np.linalg.norm(values, axis=1)
                    resampled[components] = values / norms[:, np.newaxis]
            sensors[sensor] = resampled
    return sessiondata


def calc_wheelspeed(sessiondata, camber=15, wsize=0.31, wbase=0.60, sfreq=400):
    """ Calculate wheelchair velocity based on NGIMU data, modifies dataframes inplace.

    Reads the "sensors" dataframe of the "Frame", "Left" and "Right" devices and adds the computed columns to them.
    The three dataframes are combined element-wise by index, so they should share one time base (e.g. call
    resample_imu first). Gyroscope columns are treated as deg/s, because they are converted with pi/180.

    Skid correction from Rienk vd Slikke, please refer and reference to: Van der Slikke, R. M. A., et. al.
    Wheel skid correction is a prerequisite to reliably measure wheelchair sports kinematics based on inertial
    sensors. Procedia Engineering, 112, 207-212.

    Columns added:
        Left/Right: GyroCor, GyroVel, GyroDist
        Frame: GyroCor, CombVel, CombDist, CombVelRight, CombVelLeft, CombSkidVel, CombSkidDist

    :param sessiondata: original sessiondata structure
    :param camber: camber angle in degrees
    :param wsize: radius of the wheels
    :param wbase: width of wheelbase
    :param sfreq: sample frequency
    :return: the same sessiondata structure, with the new columns added
    """
    # These are the DataFrame objects stored in sessiondata, so adding columns modifies them in place.
    frame = sessiondata["Frame"]["sensors"]
    left = sessiondata["Left"]["sensors"]  # most variables will be added to df except for some temp variables
    right = sessiondata["Right"]["sensors"]

    # Wheelchair camber correction of the wheel gyroscope Y axis, using the frame's Z-axis angular velocity
    deg2rad = np.pi / 180
    camber_rad = camber * deg2rad
    right["GyroCor"] = right["GyroscopeY"] + np.tan(camber_rad) * (frame["GyroscopeZ"] * np.cos(camber_rad))
    left["GyroCor"] = left["GyroscopeY"] - np.tan(camber_rad) * (frame["GyroscopeZ"] * np.cos(camber_rad))
    frame["GyroCor"] = (right["GyroCor"] + left["GyroCor"]) / 2

    # Calculation of wheelspeed and displacement
    right["GyroVel"] = right["GyroCor"] * wsize * deg2rad  # angular velocity to linear velocity
    right["GyroDist"] = cumtrapz(right["GyroVel"] / sfreq, initial=0.0)  # integral of velocity gives distance

    left["GyroVel"] = left["GyroCor"] * wsize * deg2rad
    left["GyroDist"] = cumtrapz(left["GyroVel"] / sfreq, initial=0.0)

    frame["CombVel"] = (right["GyroVel"] + left["GyroVel"]) / 2  # mean velocity
    frame["CombDist"] = (right["GyroDist"] + left["GyroDist"]) / 2  # mean distance

    # Skid correction (Van der Slikke et al., see docstring).
    # Frame centre velocity estimated from each wheel separately, corrected with the frame's rotation per sample.
    rotation_term = np.tan(np.deg2rad(frame["GyroscopeZ"] / sfreq)) * wbase / 2 * sfreq
    frame["CombVelRight"] = np.gradient(right["GyroDist"]) * sfreq - rotation_term
    frame["CombVelLeft"] = np.gradient(left["GyroDist"]) * sfreq + rotation_term

    abs_vel_r = np.abs(right["GyroVel"])
    abs_vel_l = np.abs(left["GyroVel"])
    abs_acc_r = np.abs(np.gradient(right["GyroVel"]))
    abs_acc_l = np.abs(np.gradient(left["GyroVel"]))
    with np.errstate(divide="ignore", invalid="ignore"):  # 0/0 samples become NaN and are handled below
        r_ratio0 = abs_vel_r / (abs_vel_r + abs_vel_l)  # Ratio left and right
        l_ratio0 = abs_vel_l / (abs_vel_r + abs_vel_l)
        r_ratio1 = abs_acc_l / (abs_acc_r + abs_acc_l)
        l_ratio1 = abs_acc_r / (abs_acc_r + abs_acc_l)
        comb_ratio = (r_ratio0 * r_ratio1) / ((r_ratio0 * r_ratio1) + (l_ratio0 * l_ratio1))  # Combine ratios

    # Both wheels at zero velocity (or zero acceleration) give 0/0 = NaN. filtfilt would spread a single NaN over
    # the whole signal, so weight both wheels equally for those samples.
    comb_ratio = np.where(np.isnan(comb_ratio), 0.5, comb_ratio)
    comb_ratio = lowpass_butter(comb_ratio, sfreq=sfreq, co=20)  # Filter the signal
    comb_ratio = np.clip(comb_ratio, 0, 1)  # clamp Combratio values, not in df
    frame["CombSkidVel"] = (frame["CombVelRight"] * comb_ratio) + (frame["CombVelLeft"] * (1 - comb_ratio))
    frame["CombSkidDist"] = cumtrapz(frame["CombSkidVel"], initial=0.0) / sfreq  # Combined skid displacement
    return sessiondata


def lowpass_butter(array, sfreq, co=20, order=2):
    """Zero-phase low-pass Butterworth filter.

    :param array: signal to filter
    :param sfreq: sample frequency of the signal
    :param co: cut-off frequency, in the same unit as sfreq
    :param order: order of the Butterworth design
    :return: filtered signal as a numpy array
    """
    nyquist = sfreq / 2
    # noinspection PyTupleAssignmentBalance
    numerator, denominator = butter(order, co / nyquist, btype="low")
    # filtfilt runs the filter forwards and backwards, so the result has no phase shift
    return filtfilt(numerator, denominator, array)


def pd_interp(df, interp_column, at):
    """Resample every column of a DataFrame onto new sample points with SciPy's interp1d.

    Values are linearly interpolated. Points outside the original range are linearly extrapolated.

    :param df: DataFrame to resample
    :param interp_column: name of the column holding the original sample points (e.g. "Time")
    :param at: new sample points at which every column is evaluated
    :return: new DataFrame with one row per value in ``at``; ``interp_column`` is set to ``at`` itself
    """
    sample_points = df[interp_column]
    resampled = {
        name: interp1d(sample_points, df[name], bounds_error=False, fill_value="extrapolate")(at)
        for name in df.columns
    }
    result = pd.DataFrame(resampled)
    result[interp_column] = at
    return result


def main():
    """Run the example pipeline on ./ngimu_example_data and plot the skid-corrected frame speed over time."""
    session = load_session("ngimu_example_data", filenames=["sensors"])
    session = resample_imu(session, samplefreq=400)
    session = calc_wheelspeed(session)

    frame_sensors = session["Frame"]["sensors"]
    plt.plot(frame_sensors["Time"], frame_sensors["CombSkidVel"])
    plt.autoscale(tight=True)
    plt.xlabel("Time (s)")
    plt.ylabel("Skid corrected speed (m/s)")
    plt.show()


if __name__ == "__main__":  # run the example only when executed as a script, not when imported
    main()
