from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
from sklearn.multiclass import OneVsRestClassifier
import matplotlib.pyplot as plt
from itertools import cycle
import warnings
from scipy import interp
from sklearn import svm
import numpy as np
import sys

args = sys.argv
if(len(args)<4):
	print("Usage: python3 %s celltype/antigen gram_n(1 to 10) filter_n(1-10)" % args[0])

else:
	label_file = "label_feature/c"+args[3]+"_unclassified_training_curated_"+args[1]+"_intl2.txt"
	#label_file = "label_feature/c10_celltype_intl2_200.txt"
	f=open(label_file,"r")
	lines=f.readlines()
	label=[]
	for line in lines:
		label.append(int(line))
	f.close()

	classes=set()
	for i in range(len(label)):
		classes.add(label[i])
	sort_classes=sorted(classes)

	bi_label = label_binarize(label, classes=sort_classes)
	n_classes = bi_label.shape[1]
	
	if(args[1]=="celltype"):
		y_train = bi_label[0:4362,]
		y_test = bi_label[4362:5066,]
	if(args[1]=="antigen"):
		y_train = bi_label[0:6953,]
		y_test = bi_label[6953:8313,]
	#y_train = bi_label[0:100,]
	#y_test = bi_label[100:200,]

	print(y_train.shape, y_test.shape)

	train_file = "word_data/c"+args[3]+"_unclassified_training_"+args[1]+"_"+args[2]+"gramfeature.txt"
	#train_file="word_data/c10_train_celltype_2-4gramfeature.txt"
	f = open(train_file,"r")
	lines=f.readlines()
	row_num=len(lines)
	col_num=len(lines[0].split())
	predict_score2=[]
	for line in lines[0:]:
		tmp = []
		for i in line.split()[0:]:
			tmp.append(float(i))
		predict_score2.append(tmp)
	X_train=np.empty((0,col_num), float)
	for x in predict_score2:
		X_train = np.append(X_train,np.array([x]),axis=0)

	f.close()

	test_file = "word_data/c"+args[3]+"_unclassified_curated_"+args[1]+"_"+args[2]+"gramfeature.txt"
	#test_file="word_data/c10_curated_celltype_2-4gramfeature.txt"
	f = open(test_file,"r")
	lines=f.readlines()
	row_num=len(lines)
	col_num=len(lines[0].split())
	predict_score2=[]
	for line in lines[0:]:
        	tmp = []
        	for i in line.split()[0:]:
                	tmp.append(float(i))
        	predict_score2.append(tmp)
	X_test = np.empty((0,col_num), float)
	for x in predict_score2:
        	X_test = np.append(X_test,np.array([x]),axis=0)
	f.close()

	print(X_train.shape,X_test.shape)

	classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True,
                                 random_state=0))
	clf=classifier.fit(X_train, y_train)
	y_score = clf.decision_function(X_test)
	y_predict = clf.predict(X_test)

	print("score: ", type(y_score), y_score)
	print("predict: ", type(y_predict), y_predict)
	np.savetxt('yscore_c'+args[3]+'_'+args[1]+'_'+args[2]+'gram.txt',y_score, fmt='%-7.4f')
	np.savetxt('ypred_c'+args[3]+'_'+args[1]+'_'+args[2]+'gram.txt',y_predict, fmt='%-7.4f')

# Compute ROC curve and ROC area for each class
	fpr = dict()
	tpr = dict()
	roc_auc = dict()

	for i in range(n_classes):
		fpr[i], tpr[i], _  = roc_curve(y_test[:, i], y_score[:, i])
		roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
	fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
	roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

	print("micro fpr:", fpr["micro"])
	print("micro tpr:", tpr["micro"])

# Compute macro-average ROC curve and ROC area

# First aggregate all false positive rates
	all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

# Then interpolate all ROC curves at this points
	mean_tpr = np.zeros_like(all_fpr)
	for i in range(n_classes):
		mean_tpr += interp(all_fpr, fpr[i], tpr[i])

# Finally average it and compute AUC
	mean_tpr /= n_classes

	fpr["macro"] = all_fpr
	tpr["macro"] = mean_tpr
	roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])


	print("macro fpr:", fpr["macro"])
	print("macro tpr:", tpr["macro"])

	colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
	lw=2

# Plot all ROC curves
	plt.figure()
	lw=2
	plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

	plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

	plt.plot([0, 1], [0, 1], 'k--', lw=lw)
	plt.xlim([0.0, 1.0])
	plt.ylim([0.0, 1.05])
	plt.xlabel('False Positive Rate')
	plt.ylabel('True Positive Rate')
	plt.title('micro-averaging ROC curve to c'+args[3]+'_'+args[1]+'_'+args[2]+'gram')
	plt.legend(loc="lower right")
	figname = "fig_"+args[1]+"_allclass_roc"
	plt.savefig('result-fig/'+figname)

# Plot ROC curve for each class
	plt.clf()
	for i, color in zip(range(n_classes), colors):
		plt.plot(fpr[i], tpr[i], color=color, lw=lw)
	plt.xlim([0.0, 1.0])
	plt.ylim([0.0, 1.05])
	plt.xlabel('False Positive Rate')
	plt.ylabel('True Positive Rate')
	plt.title('ROC curve of c'+args[3]+'_'+args[1]+'_'+args[2]+'gram')
	plt.legend(loc="lower right")
	figname = "fig_"+args[1]+"_eachclass_roc"
	plt.savefig('result-fig/'+figname)


# Plot ROC curve for each class
	num=1
	for i, color in zip(range(n_classes), colors):
		plt.clf()
		plt.plot(fpr[i], tpr[i], color=color, lw=lw,label='ROC curve of class {0} (area = {1:0.2f})'
		''.format(i, roc_auc[i]))

		plt.xlim([0.0, 1.0])
		plt.ylim([0.0, 1.05])
		plt.xlabel('False Positive Rate')
		plt.ylabel('True Positive Rate')
		plt.title('ROC curve of 10_antigen_2-4gram')
		plt.legend(loc="lower right")
		figname = "fig_"+args[1]+"_roc"+str(num)
		plt.savefig('result-fig/'+figname)
		num=num+1

	f=open('micro_fpr','w')
	f.write(str(fpr["micro"]))
	f.close()
	f=open('micro_tpr','w')
	f.write(str(tpr["micro"]))
	f.close()

	f=open('macro_fpr','w')
	f.write(str(fpr["macro"]))
	f.close()
	f=open('macro_tpr','w')
	f.write(str(tpr["macro"]))
	f.close()
