#!/usr/bin/env python3
'''
This script fits features to qEASL viable
volume and produces a predictive model. See
help for more information.
'''

import nibabel, argparse, tabulate, matplotlib, sys
import numpy as np
import matplotlib.pyplot as plt
from skimage import measure
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier

def load_tsv(tsv_file, array_valued=True, column_filter=None):
    '''
    Read a tab-separated file into a dict keyed by the first column.

    Parameters:
        tsv_file: Path of the tab-separated file to read.
        array_valued: When true, each row maps to a float numpy array built
            from the selected columns. When false, each row maps to the raw
            string found in its second column.
        column_filter: Optional comma-separated string of zero-based column
            indices to keep (e.g. '2,5,8'). When omitted, every column after
            the ID column is used; the count is taken from the first
            non-blank row, and the resulting list is printed.

    Returns:
        dict mapping each row's first field to its value(s).

    Raises:
        ValueError: If a data row holds a non-numeric value in a selected
            column (array mode), or a row has no value column (scalar mode).

    Notes:
        - Blank lines are ignored.
        - In array mode a first row that cannot be parsed as floats is
          treated as the header row and skipped. In scalar mode a header
          cannot be told apart from data, so it is kept as a normal entry.
        - Selected values are returned in file column order; indices beyond
          the end of a row are ignored.
    '''
    if column_filter:
        column_filter = [int(column) for column in column_filter.split(',')]

    # Set view of the filter for O(1) membership tests per field.
    wanted = set(column_filter) if column_filter else set()

    values_by_id = {}
    seen_row = False
    with open(tsv_file) as f:
        for line_number, line in enumerate(f, start=1):
            fields = line.strip().split('\t')

            # A blank (or whitespace-only) line carries no record.
            if fields == ['']:
                continue

            if not wanted:
                # Default: every column except the leading ID column.
                column_filter = list(range(1, len(fields)))
                print(column_filter)
                wanted = set(column_filter)

            is_first_row = not seen_row
            seen_row = True

            if not array_valued:
                if len(fields) < 2:
                    raise ValueError(
                        '%s line %d: expected an ID and a value column'
                        % (tsv_file, line_number))
                values_by_id[fields[0]] = fields[1]
                continue

            selected = [value for j, value in enumerate(fields) if j in wanted]
            try:
                values_by_id[fields[0]] = np.array(selected, dtype='float')
            except ValueError:
                # A non-numeric first row is the header; anything later is
                # genuinely malformed data and must not be hidden.
                if is_first_row:
                    continue
                raise

    return values_by_id

def _safe_ratio(numerator, denominator, scale=1.0):
    '''Return scale * numerator / denominator, or NaN for a zero denominator.'''
    if not denominator:
        return float('nan')
    return scale * numerator / denominator


def _build_model(model_type):
    '''Create a fresh, unfitted classifier for the given model type code.'''
    if model_type == 'lr':
        # C is the inverse regularization strength in scikit-learn, so a
        # very large value leaves the fit essentially unregularized.
        return LogisticRegression(C=1e15)
    if model_type == 'svm':
        # probability=True is required for predict_proba on an SVC.
        return SVC(kernel='rbf', probability=True)
    if model_type == 'rf':
        return RandomForestClassifier(n_estimators=100)
    raise ValueError('Unknown model type: %r' % (model_type,))


def fit_model(features_by_id, labels_by_id, model_type, nonresponder_threshold=0.8):
    '''
    Evaluate a classifier with leave-one-out cross-validation and print a report.

    Every ID present in both mappings is held out once; a fresh model is
    trained on the remaining labelled samples and asked for class
    probabilities on the held-out one. A sample is called a non-responder
    when the predicted probability of the 'Non-Responder' class exceeds
    nonresponder_threshold, otherwise a responder.

    Parameters:
        features_by_id: dict mapping sample ID to its feature vector.
        labels_by_id: dict mapping sample ID to its label string
            ('Responder' or 'Non-Responder').
        model_type: 'lr' (logistic regression), 'svm' (RBF support vector
            classifier) or 'rf' (random forest).
        nonresponder_threshold: Probability the 'Non-Responder' class must
            exceed for a sample to be called a non-responder.

    Returns:
        None. Per-sample predictions, per-class and overall accuracy,
        negative/positive predictive value and, for 'lr', the coefficients
        and intercept averaged over all folds are printed to stdout. Ratios
        with a zero denominator are printed as nan.

    Raises:
        ValueError: If model_type is not recognised, or fewer than two IDs
            have both features and a label.
    '''
    if model_type not in ('lr', 'svm', 'rf'):
        raise ValueError('Unknown model type: %r' % (model_type,))

    # Only samples that have both features and a label take part.
    labeled_ids = [_id for _id in features_by_id if _id in labels_by_id]
    if len(labeled_ids) < 2:
        raise ValueError(
            'Leave-one-out analysis needs at least two samples with both '
            'features and labels; found %d' % len(labeled_ids))

    num_responders_correct = 0
    num_responders_incorrect = 0
    num_nonresponders_correct = 0
    num_nonresponders_incorrect = 0

    coef_sum = None
    intercept_sum = None
    num_lr_fits = 0
    model = None

    # We perform leave-one-out analysis.
    for held_out in sorted(labeled_ids):
        train_ids = [_id for _id in labeled_ids if _id != held_out]
        train_x = np.array([features_by_id[_id] for _id in train_ids])
        train_y = np.array([labels_by_id[_id] for _id in train_ids])
        test_x = [features_by_id[held_out]]

        model = _build_model(model_type)
        model.fit(train_x, train_y)

        # estimators_ exists only on the random forest; reading it on the
        # other model types raises AttributeError.
        if model_type == 'rf':
            print(model.estimators_)

        if model_type == 'lr':
            # Accumulate into copies so the fitted model's own arrays are
            # never modified in place.
            if coef_sum is None:
                coef_sum = model.coef_[0].copy()
                intercept_sum = model.intercept_.copy()
            else:
                coef_sum += model.coef_[0]
                intercept_sum += model.intercept_
            num_lr_fits += 1

        probabilities = model.predict_proba(test_x)

        # predict_proba columns follow model.classes_, so look the class up
        # instead of assuming column 0; a fold trained without any
        # non-responders has no such column at all.
        classes = list(model.classes_)
        if 'Non-Responder' in classes:
            nonresponder_probability = probabilities[0][classes.index('Non-Responder')]
        else:
            nonresponder_probability = 0.0
        is_nonresponder = nonresponder_probability > nonresponder_threshold

        actual = labels_by_id[held_out]
        if is_nonresponder:
            if actual == 'Non-Responder':
                num_nonresponders_correct += 1
            else:
                num_responders_incorrect += 1
        else:
            if actual == 'Responder':
                num_responders_correct += 1
            else:
                num_nonresponders_incorrect += 1

        print('%s\t%s\t%s' % (held_out, 'Non-responder' if is_nonresponder else 'Responder', probabilities[0]))

    # Totals are derived from the four outcome counters.
    num_responders = num_responders_correct + num_responders_incorrect
    num_nonresponders = num_nonresponders_correct + num_nonresponders_incorrect
    num_correct = num_responders_correct + num_nonresponders_correct
    num_total = num_responders + num_nonresponders

    print('Responders Accuracy: %s/%s (%s%%)' % (num_responders_correct, num_responders, _safe_ratio(num_responders_correct, num_responders, 100.0)))
    print('Non-responders Accuracy: %s/%s (%s%%)' % (num_nonresponders_correct, num_nonresponders, _safe_ratio(num_nonresponders_correct, num_nonresponders, 100.0)))
    print('Overall Accuracy: %s/%s (%s%%)' % (num_correct, num_total, _safe_ratio(num_correct, num_total, 100.0)))

    # NPV: of the samples predicted non-responder, the fraction that truly are.
    print('Negative predictive value: %s' % _safe_ratio(num_nonresponders_correct, num_nonresponders_correct + num_responders_incorrect))
    # PPV: of the samples predicted responder, the fraction that truly are.
    print('Positive predictive value: %s' % _safe_ratio(num_responders_correct, num_responders_correct + num_nonresponders_incorrect))

    if model_type == 'lr':
        # Mean coefficients and intercept across all leave-one-out fits.
        print('\t'.join(str(c) for c in (coef_sum / num_lr_fits)))
        print(intercept_sum / num_lr_fits)

    if model_type == 'rf':
        # Shows the model fitted in the final fold.
        print(model)

if __name__ == "__main__":
    # Command-line entry point: load features and labels, then run the
    # leave-one-out evaluation for the requested model type.
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument('-f', '--features_file', type=str, required=True, help='File containing features to be used for fitting. First row should be header, first column should be unique IDs.')
    argument_parser.add_argument('-c', '--feature_columns_to_use', type=str, help='Numeric list of columns in the features list to use, separated by commas (e.g. 2,5,8)')
    argument_parser.add_argument('-l', '--labels_file', type=str, required=True, help='File containing labels to be used for fitting. First row should be header, first column should be unique IDs.')
    argument_parser.add_argument('-m', '--model_type', type=str, choices=['lr', 'svm', 'rf'], required=True, help='The type of model to fit.')
    arguments = argument_parser.parse_args()

    # Features are numeric vectors, so request array-valued rows explicitly
    # (the model type string must not be passed as the array_valued flag).
    features_by_id = load_tsv(arguments.features_file, array_valued=True, column_filter=arguments.feature_columns_to_use)
    # Labels are single strings taken from the second column.
    labels_by_id = load_tsv(arguments.labels_file, array_valued=False, column_filter=None)

    fit_model(features_by_id, labels_by_id, arguments.model_type)
