library(ggplot2) 
library(ggrepel)

# ---- Run configuration ----
# Data directory and the settings that are spliced into the file names below.
dir <- "word_data/"
type <- "celltype" # celltype or antigen
label <- "2" # "1" or "2"
num <- "2-4" # "1" to "10"
filt <- "10" # "1" to "5"

# Accuracy accumulator; starts empty and receives rows via rbind() later on.
acc_mat <- c()

# Class count and test-set size. Any `type` other than "celltype" takes the
# second value of each pair.
class_num <- if (type == "celltype") 133 else 128 # alternatives noted: 252/783/87 and 283/748/76
test_num <- if (type == "celltype") 704 else 1360

# Per-class counters of correct and wrong predictions (1-d arrays of zeros).
true_class_total <- array(0, dim = class_num)
false_class_total <- array(0, dim = class_num)

# Classifier output file, followed by the label files of the test ("curated")
# and training sets.
result_file <- paste0(
  dir, "output_c", filt, "_cost1_unclassified_curated_", type,
  "_intl", label, "f", num, "gram.txt"
)
test_label_file <- paste0(
  "label_feature/c", filt, "_unclassified_curated_", type,
  "_intl", label, ".txt"
)
train_label_file <- paste0(
  "label_feature/c", filt, "_unclassified_training_", type,
  "_intl", label, ".txt"
)

# Load the classifier output plus the true labels of the test and training
# sets (all read without a header line).
test_result <- read.table(file = result_file)
test_true_label <- read.table(file = test_label_file)
train_true_label <- read.table(file = train_label_file)
test_result.df <- data.frame(test_result)

# The first column of rows 2..(test_num + 1) is taken as the predicted label
# of each test sample; row 1 is skipped here.
predict_label <- data.frame(test_result.df[2:(test_num + 1), 1])

# Side-by-side table: true label in column 1, predicted label in column 2.
true_predict_label <- cbind(test_true_label, predict_label)
colnames(true_predict_label) <- c("true", "predicted")

# Save the (true, predicted) label pairs, one row per test sample, with no
# header, no row names and no quoting.
predict_file <- paste0(
  dir, "c", filt, "_cost1_unclassified_curated_", type, "_", num, "gram",
  "_labelcheck"
)
write.table(
  true_predict_label,
  file = predict_file,
  quote = FALSE,
  row.names = FALSE,
  col.names = FALSE
)

# ---- Score every test sample against its true label ----
# Layout used below: row 1 of test_result.df holds the class id belonging to
# each score column (columns 2..n); every following row holds the scores of
# one test sample, in the same order as the rows of true_predict_label.

# Number of best-scoring classes that count as "predicted"
# (1 = top-1 accuracy, 3 = top-3 accuracy).
top_k <- 1L

# Most frequent training label; it is treated as the "unclassified" class.
index_unclassified <- sort(table(train_true_label), decreasing = TRUE)[1]
index_unclassified <- as.integer(names(index_unclassified))

score_cols <- seq_len(ncol(test_result.df))[-1]
class_ids <- unlist(test_result.df[1, score_cols], use.names = FALSE)

# Work on a plain numeric matrix: ordering a numeric vector is well defined,
# whereas sorting a one-row data frame depends on xtfrm() for data frames.
score_mat <- as.matrix(test_result.df[-1, score_cols, drop = FALSE])
n_samples <- nrow(score_mat)

# TRUE where the sample counts as correctly predicted.
is_correct <- logical(n_samples)

# seq_len() keeps the loop from running when the file has no sample rows.
for (i in seq_len(n_samples)) {
  top_cols <- order(score_mat[i, ], decreasing = TRUE)[seq_len(top_k)]
  top_classes <- class_ids[top_cols]
  true_label <- true_predict_label[i, 1]
  # The counters are indexed by label + 1 (labels appear to be zero-based).
  counter_pos <- true_label + 1

  # A sample whose true label never occurs in the training labels is counted
  # as correct when it is predicted as "unclassified".
  unseen_as_unclassified <- (index_unclassified %in% top_classes) &&
    !(true_label %in% train_true_label[, 1])
  is_correct[i] <- unseen_as_unclassified || (true_label %in% top_classes)

  if (is_correct[i]) {
    true_class_total[counter_pos] <- true_class_total[counter_pos] + 1
  } else {
    false_class_total[counter_pos] <- false_class_total[counter_pos] + 1
  }
}

# 1-based positions (within the test set) of correct and wrong predictions.
true_index <- which(is_correct)
false_index <- which(!is_correct)

# Write the positions of the wrongly predicted samples, one per line ...
false_file <- paste0(
  dir, "c", filt, "_cost1_", type, "_curated_l", label, "_", num, "gram",
  "_falseindex"
)
write.table(
  false_index,
  file = false_file,
  sep = "\n",
  quote = FALSE,
  row.names = FALSE,
  col.names = FALSE
)

# ... and the positions of the correctly predicted samples.
true_file <- paste0(
  dir, "c", filt, "_cost1_", type, "_curated_l", label, "_", num, "gram",
  "_trueindex"
)
write.table(
  true_index,
  file = true_file,
  sep = "\n",
  quote = FALSE,
  row.names = FALSE,
  col.names = FALSE
)
  
  
# Write the accuracies to file: the first value is the overall accuracy, the
# remaining values are the accuracies of the individual classes.
filename <- paste0(
  dir, "c", filt, "_cost1_", type, "_curated_l", label, "_", num, "gram",
  "_accuracy"
)

# A class without any correct prediction gets accuracy 0 (this also covers
# classes that have no test samples at all).
acc <- c(
  sum(true_class_total) / (sum(true_class_total) + sum(false_class_total)),
  ifelse(
    true_class_total == 0,
    0,
    true_class_total / (true_class_total + false_class_total)
  )
)
write.table(
  acc,
  file = filename,
  sep = "\n",
  quote = FALSE,
  row.names = FALSE,
  col.names = FALSE
)

# Append this accuracy vector as a new row of the accumulator and show the
# first row.
acc_mat <- rbind(acc_mat, acc)
acc_mat[1, ]



##----------------plot Mean accuracy and error bar of each class for 10-fold cv----------##

# Column-wise mean and standard deviation over the rows of acc_mat.
# The results are stored under the names `mean` and `sd` because later code
# reads them by those names; the functions are therefore called with their
# namespace prefix.
mean <- vapply(
  seq_len(ncol(acc_mat)),
  function(j) base::mean(acc_mat[, j]),
  numeric(1)
)
sd <- vapply(
  seq_len(ncol(acc_mat)),
  function(j) stats::sd(acc_mat[, j]),
  numeric(1)
)

## ---- Plot the mean accuracy (with error bars) of every class ----

# `cost` is spliced into the output file name and the plot title but is not
# assigned earlier in this script. Fall back to "1", which matches the
# "_cost1_" tag of the other file names, unless it has already been defined.
if (!exists("cost")) {
  cost <- "1"
}

# Class table: the second column is used as the class name.
class_file <- paste0("label_feature/c10_", type, "_training_l", label, "_class.txt")
class1 <- read.table(file = class_file, sep = "\t")
col.df <- data.frame(V1 = class1[, 2])

# `mean` and `sd` hold one value for the overall accuracy followed by one
# value per class, so they must be exactly one element longer than the class
# table.
stopifnot(
  length(mean) == nrow(class1) + 1L,
  length(sd) == length(mean)
)

# One row per plotted item: "Average" first, then every class in file order.
l2_class <- data.frame(
  Class = c("Average", as.character(class1[, 2])),
  MeanAccuracy = mean,
  SD = sd,
  stringsAsFactors = FALSE
)

l2_class

dir_out <- paste0(dir)
fo_img <- sprintf(
  "%sc10_cost%s_%s_training_l%s_class_%sgram_acc.pdf",
  dir_out, cost, type, label, num
)

# Two-line title; the second line reports the overall ("Average") accuracy.
plot_title <- paste0(
  "Average accuracy of each class (", nrow(class1), " classes in Label ", label,
  ") for human ", type, " data\n        Total average is ",
  l2_class[1, 2], "(word, ", num, "gram, cost", cost, ")"
)

pdf(fo_img, width = 20, height = 9, pointsize = 9)
g <- ggplot(l2_class, aes(x = Class, y = MeanAccuracy, colour = Class)) +
  geom_errorbar(aes(ymin = MeanAccuracy - SD, ymax = MeanAccuracy + SD), width = .5) +
  geom_line() +
  geom_point() +
  theme(
    # Rotate the class names so they stay readable on the x axis.
    axis.text.x = element_text(angle = 90, hjust = 1),
    plot.title = element_text(
      family = "Helvetica", color = "#666666", face = "bold",
      size = 16, hjust = 0.5
    )
  ) +
  labs(title = plot_title)
plot(g)
dev.off()

# ---- Extract the "bad cases": rows of l2_class with mean accuracy < 0.8 ----

# `cost` is part of the output file name but is not assigned earlier in this
# script. Fall back to "1", which matches the "_cost1_" tag of the other file
# names, unless it has already been defined.
if (!exists("cost")) {
  cost <- "1"
}

filename <- paste0(
  "~/Documents/Automatic-Annotation/Analyze/cost", cost, "_", type,
  "_training_l", label, "_", num, "gram", "_badcase08_name.txt"
)

l2_class[, 2]
ind <- which(l2_class[, 2] < 0.8)
ind

# Names of the low-accuracy rows (column 1 of l2_class), written one per line.
badcase <- l2_class[ind, 1]
badcase
vec_badcase <- as.vector(badcase)
write(vec_badcase, file = filename, sep = "\n")

class1[, 2]

# Zero-based ids of the bad classes: position in the class table minus one.
# Names that do not occur in the class table (such as the "Average" row)
# contribute nothing. lapply() also copes with an empty vec_badcase.
bad_class_id <- unlist(
  lapply(vec_badcase, function(nm) which(class1[, 2] == nm) - 1)
)
bad_class_id

false_class_total
true_class_total

# Per-class accuracy straight from the counters (NaN for classes without any
# counted sample) and the positions of the classes below 0.8.
acc <- as.vector(true_class_total / (false_class_total + true_class_total))
id <- which(acc < 0.8)

sum(false_class_total)
sum(true_class_total)

