import os
import sys
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm
from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

# Class labels: one integer per line. The first TRAIN_END rows label the
# training feature file, the next rows label the test feature file.
label_file = "label_feature/c10_unclassified_training_curated_celltype_intl2.txt"
#label_file = "label_feature/c10_celltype_intl2_200.txt"
with open(label_file, "r") as f:
    # int() ignores the trailing newline; a blank line still raises ValueError.
    label = [int(line) for line in f]

# Sorted distinct labels fix the column order of the binarized label matrix.
classes = set(label)
sort_classes = sorted(classes)

# One-vs-rest indicator matrix with one column per entry of sort_classes.
bi_label = label_binarize(label, classes=sort_classes)
n_classes = bi_label.shape[1]

# Rows [0, TRAIN_END) form the training set, rows [TRAIN_END, TEST_END) the
# test set. These must agree with the row counts of the feature files.
TRAIN_END = 4362
TEST_END = 5066
y_train = bi_label[:TRAIN_END]
y_test = bi_label[TRAIN_END:TEST_END]
#y_train = bi_label[0:100,]
#y_test = bi_label[100:200,]

print(y_train.shape, y_test.shape)

def _read_feature_matrix(path):
    """Read a whitespace-separated numeric file into a 2-D float array.

    Each line of *path* becomes one row of the result.

    Raises:
        ValueError: if the file is empty, a token is not a float, or lines
            have differing numbers of columns.
    """
    with open(path, "r") as fh:
        rows = [[float(tok) for tok in line.split()] for line in fh]
    if not rows:
        raise ValueError("feature file %r is empty" % (path,))
    # One bulk conversion instead of np.append per row, which re-copied the
    # whole matrix each time (quadratic in the number of rows).
    return np.array(rows, dtype=float)


train_file = "word_data/c10_unclassified_training_celltype_2-4gramfeature.txt"
#train_file="word_data/c10_train_celltype_2-4gramfeature.txt"
X_train = _read_feature_matrix(train_file)

test_file = "word_data/c10_unclassified_curated_celltype_2-4gramfeature.txt"
#test_file = "word_data/c10_curated_celltype_2-4gramfeature.txt"
X_test = _read_feature_matrix(test_file)

print(X_train.shape,X_test.shape)

# Features and labels are paired purely by row position; fail early with a
# clear message rather than deep inside scikit-learn if the counts disagree.
if X_train.shape[0] != y_train.shape[0] or X_test.shape[0] != y_test.shape[0]:
    raise ValueError(
        "feature/label row mismatch: X_train %s vs y_train %s, X_test %s vs y_test %s"
        % (X_train.shape, y_train.shape, X_test.shape, y_test.shape))

# Linear-kernel SVM trained one-vs-rest: one binary SVC per label column.
# probability=True is intentionally not set. It makes SVC.fit run an extra
# internal 5-fold cross-validation (Platt scaling) that only predict_proba
# needs; decision_function, used below, comes from the main fitted model.
# random_state only seeds that calibration, so it is kept for reproducibility
# in case probability estimates are re-enabled.
random_state = np.random.RandomState(0)
classifier = OneVsRestClassifier(svm.SVC(kernel='linear',
                                         random_state=random_state))
# Signed distances to each class's separating hyperplane; one column per class.
y_score = classifier.fit(X_train, y_train).decision_function(X_test)

# Precision-recall curve and average precision (AP) per class, keyed by the
# class's column index in the binarized label matrix.
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                        y_score[:, i])
    average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])

# Micro-averaged precision-recall curve and AP: every (sample, class) pair is
# pooled into one binary problem, stored under the key "micro".
precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(),
    y_score.ravel())
average_precision["micro"] = average_precision_score(y_test, y_score,
                                                     average="micro")
# Running index for output figure names: result-fig/fig0, fig1, ...
num = 0
# Line width shared by every plotted curve.
lw = 2
# plt.savefig does not create missing directories, so make sure the output
# directory exists before the first figure is written.
os.makedirs('result-fig', exist_ok=True)

# Figure: micro-averaged precision-recall curve.
plt.clf()
plt.plot(recall["micro"], precision["micro"], lw=lw, color='navy',
         label='micro-average Precision-recall curve (area = {0:0.2f})'
         .format(average_precision["micro"]))
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
# NOTE(review): the value shown as "AUC" is the micro-averaged average
# precision score, not a trapezoidal area under the curve.
plt.title('Precision-Recall curve: AUC={0:0.2f}'.format(average_precision["micro"]))
plt.legend(loc="lower left")
figname = "fig" + str(num)
plt.savefig(os.path.join('result-fig', figname))
num = num + 1

# Colors for per-class curves; `colors` is a shared, endlessly repeating cycle.
COLOR_PALETTE = ['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal']
colors = cycle(COLOR_PALETTE)

# Figure: precision-recall curves of all classes overlaid on one axes.
plt.clf()
# Iterate a fresh cycle so the shared `colors` iterator is not advanced here.
# Otherwise later consumers of `colors` would start at an offset of
# n_classes % len(COLOR_PALETTE), giving a class a different color elsewhere.
for i, color in zip(range(n_classes), cycle(COLOR_PALETTE)):
    # Labels are required: without them plt.legend() has no entries to show.
    plt.plot(recall[i], precision[i], color=color, lw=lw,
             label='class {0} (area = {1:0.2f})'.format(i, average_precision[i]))

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall curve of each class(c10_celltype_2-4gram)')
plt.legend(loc="lower right")
figname = "fig" + str(num)
plt.savefig(os.path.join('result-fig', figname))
num = num + 1

# One separate figure per class, each showing only that class's curve.
for i, color in zip(range(n_classes), colors):
    plt.clf()
    plt.plot(recall[i], precision[i], color=color, lw=lw,
             label='Precision-recall curve of class {0} (area = {1:0.2f})'
                   .format(i, average_precision[i]))

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    # Dataset tag matches the "c10_celltype_2-4gram" used by the overview figure.
    plt.title('Precision-Recall curve of c10_celltype_2-4gram')
    plt.legend(loc="lower right")
    figname = "fig" + str(num)
    plt.savefig(os.path.join('result-fig', figname))
    num = num + 1

# Echo every computed metric to stdout as "<caption> <value>".
# NumPy's default print options may abbreviate long arrays in this output.
for caption, value in [
    ("micro recall:", recall["micro"]),
    ("micro precision:", precision["micro"]),
    ("micro average precision:", average_precision["micro"]),
    ("average precision:", average_precision),
    ("precision:", precision),
    ("recall:", recall),
]:
    print(caption, value)

# Persist each metric, as its str() form, to a file in the working directory.
# By default NumPy summarizes arrays of more than 1000 elements as
# "[a b c ... x y z]", which would silently drop most curve points from the
# saved files. Lifting the threshold keeps the format but writes every value.
metric_outputs = [
    ('micro_recall', recall["micro"]),
    ('micro_precision', precision["micro"]),
    ('micro_average_precision', average_precision["micro"]),
    ('average_precision', average_precision),
    ('precision', precision),
    ('recall', recall),
]
with np.printoptions(threshold=sys.maxsize):
    for out_path, value in metric_outputs:
        with open(out_path, 'w') as out:
            out.write(str(value))

#confs_matrix=confusion_matrix(true_label, predict_label)

#precision, recall, fscore, support = score(true_label,predict_label)
#average_precision, average_recall, average_fscore, average_support = score(true_label,predict_label,average="micro")
#accuracy = accuracy_score(true_label, predict_label)

#print('accuracy: {}'.format(accuracy))
#print('precision: {}'.format(average_precision))
#print('recall: {}'.format(average_recall))
#print('fscore: {}'.format(average_fscore))
#print('confusion matrix: {}'.format(confs_matrix.shape))
