##########################################################################################
#
#     MODEL FITTING AND PREDICTING
#
##########################################################################################
library(PDwCI)
library(randomForest)
library(dplyr)
library(plyr)
library(ggplot2)
source("Rscript/functionsNeeded.R")
source("Rscript/correlations.R")
# Predictors retained for the per-capita water-use models (YEAR and the
# response are kept alongside them by the column filter below).
pred_sel3 <- c(
  "Income", "Energy.Generation", "W.BS", "average_gsp", "Coastal",
  "Urbanization", "CDD", "HDD", "State_Area", "FarmlandPercentage", "CoalProd"
)

set.seed(659)

# Load the master panel and key every row as "<yy><STATE>" -- characters 3-4
# of YEAR (e.g. "10" for 2010, assuming 4-digit years) followed by STATE.
masterlist_ <- read.csv("masterlist_200819.csv")
row.names(masterlist_) <- paste0(
  substr(masterlist_$YEAR, 3, 4),
  masterlist_$STATE
)

# Keep only the selected predictors, YEAR and the response, in the column
# order they have in the master file.
# NOTE(review): YEAR stays in the frame, so `To_waterPercapita ~ .` uses it as
# a predictor -- confirm this is intended.
percapitaData <- masterlist_[, intersect(
  names(masterlist_),
  c(pred_sel3, "YEAR", "To_waterPercapita")
)]

# Temporal split: every year before 2015 trains, 2015 is the test year.
percapitaData_training <- subset(percapitaData, subset = YEAR < 2015)
percapitaData_test <- subset(percapitaData, subset = YEAR == 2015)

# Random forest WITHOUT SPI drought and soil-moisture predictors: fit on the
# training years, then report training fit, cross-validated RMSE and the fit
# on the held-out test year.
# NOTE(review): `~ .` includes every column of the training frame (YEAR too).
RF.percapita <- randomForest(To_waterPercapita ~ ., data = percapitaData_training,
                             importance = TRUE)

# rSQRF() comes from Rscript/functionsNeeded.R; the report below reads its
# result as [1] = R-squared, [2] = RMSE.
RF.R2.RMSE <- rSQRF(RF.percapita, percapitaData_training, "To_waterPercapita", 0)
print(RF.R2.RMSE)

# NOTE(review): 40 is presumably the number of CV folds; if so, the "LOOCV"
# wording in the report is a misnomer -- confirm in functionsNeeded.R.
RF.CV.RMSE <- rf_kfold_CV(percapitaData_training, 40, "To_waterPercapita", "rf")
print(RF.CV.RMSE)

RF.R2.RMSE.test <- rSQRF(RF.percapita, percapitaData_test, "To_waterPercapita", 0)

# Explicit cat() so the report also appears when the script is source()d,
# where top-level values are not auto-printed.
## training error AND cross validation error
cat(sprintf("Random Forest perCapita (training): Rsq: %.3f, RMSE: %.3f, LOOCV RMSE: %.3f",
            RF.R2.RMSE[1], RF.R2.RMSE[2], RF.CV.RMSE), "\n")

## predict error
cat(sprintf("Random Forest perCapita (predicting): Rsq: %.3f, RMSE: %.3f",
            RF.R2.RMSE.test[1], RF.R2.RMSE.test[2]), "\n")

##########################################################################################
#
#     PLOTTING
#
##########################################################################################

plottingNames <- read.csv("../Data/Plotting labels.csv")
today <- format(Sys.Date(), "%d%m%y")

# Actual vs predicted per-capita water use over every year in percapitaData
# (the model was fitted on the pre-2015 rows, so those points are in-sample).
# Predictions are computed once and reused for the scatter and the fit line.
fitted_percapita <- predict(RF.percapita, percapitaData)

tiff(paste0("ActualvsFit_", today, ".tiff"), width = 8, height = 5, units = "in",
     res = 400, family = "Comic Sans")
# x = actual, y = predicted: the axis labels now match the plotted data
# (they were previously swapped).
plot(percapitaData$To_waterPercapita, fitted_percapita, col = "blue",
     main = "Plot of Actual vs Predicted \nTotal Water Usage per Capita (mil gal/person)",
     xlim = c(0, 15), ylim = c(0, 15),
     xlab = "Actual", ylab = "Predicted", pch = 18, cex = 0.7)
# Least-squares line of predicted on actual, matching the axes above.
abline(lm(fitted_percapita ~ percapitaData$To_waterPercapita))
dev.off()

# Partial-dependence plots on the test year. PDP_cluster() is defined
# elsewhere; the meaning of `n` is not visible here.
n <- 10
PDP_cluster(RF.percapita, plottingNames, percapitaData_test, n,
            "Total Water Usage PerCapita ", "To_waterPercapita", paste0(today, "_"))

##########################################################################################
#
#     STATE HOLD OUT
#
##########################################################################################
# State hold-out experiment: for each holdout size, draw n_reps random sets of
# states, hold out every year of those states, refit the forest on the rest
# and record training / cross-validated / held-out error.
# Loop-local names (holdout_*) are used so the main model and split fitted
# above (RF.percapita, percapitaData_training/test) are not overwritten.
holdout_sizes <- 10:20
n_reps <- 100

# Loop-invariant inputs, hoisted out of the loops (same values and the same
# sequence of random draws as computing them per iteration).
state_codes <- unique(as.character(masterlist_$STATE))
# Row names are "<yy><STATE>", so characters 3-4 are the state code.
# NOTE(review): assumes 2-character state codes -- confirm.
row_state <- substr(row.names(percapitaData), 3, 4)

holdout_results <- vector("list", length(holdout_sizes) * n_reps)
k <- 0L
for (n_hold in holdout_sizes) {
  cat(n_hold)
  for (rep_i in seq_len(n_reps)) {
    cat(".")
    # Randomly pick n_hold states to hold out, including all years of each.
    state_to_hold <- sample(state_codes, n_hold)
    in_holdout <- row_state %in% state_to_hold
    holdout_train <- percapitaData[!in_holdout, , drop = FALSE]
    holdout_test <- percapitaData[in_holdout, , drop = FALSE]

    # RF withOUT SPI drought and soil moisture, as in the main model.
    holdout_rf <- randomForest(To_waterPercapita ~ ., data = holdout_train,
                               importance = TRUE)
    train_fit <- rSQRF(holdout_rf, holdout_train, "To_waterPercapita", 0)
    cv_rmse <- rf_kfold_CV(holdout_train, 40, "To_waterPercapita", "rf")
    test_fit <- rSQRF(holdout_rf, holdout_test, "To_waterPercapita", 0)

    k <- k + 1L
    holdout_results[[k]] <- data.frame(
      no_states = n_hold,
      holdout_state = paste(state_to_hold, collapse = ", "),
      train_rmse = train_fit[2], train_r2 = train_fit[1],
      cv_rmse = cv_rmse, test_rmse = test_fit[2]
    )
  }
}
# Bind once at the end instead of growing a data frame inside the loop.
state_holdout <- do.call(rbind, holdout_results)
write.csv(state_holdout, "state_holdout.csv", row.names = FALSE)

# Summaries for plotting: one row per holdout size. For each metric the
# summarySE() output (defined elsewhere) keeps the metric's own name for the
# summary value and prefixes its remaining statistics with the metric name,
# e.g. cv_rmse, cv_rmse_se, cv_rmse_ci.
holdout_metrics <- c("train_rmse", "train_r2", "cv_rmse", "test_rmse")
state_holdout_summary <- NULL
for (metric in holdout_metrics) {
  # NOTE(review): assumes summarySE() returns columns ordered
  # <groupvar>, N, <measurevar>, sd, se, ci -- confirm.
  metric_summary <- summarySE(state_holdout, measurevar = metric,
                              groupvars = c("no_states"))
  colnames(metric_summary)[4:6] <- paste0(metric, "_", colnames(metric_summary)[4:6])
  if (is.null(state_holdout_summary)) {
    # Group labels and per-group counts come from the data rather than a
    # hard-coded 10:20 / N = 100, so they follow the holdout loop's settings.
    state_holdout_summary <- metric_summary[, c(1, 2)]
  }
  state_holdout_summary <- cbind(state_holdout_summary, metric_summary[, 3:6])
}


# Cross-validated RMSE with 95% confidence-interval error bars (the *_ci
# columns) instead of standard errors.
# Error bars may overlap, so geoms use a shared position_dodge(0.3); `pd` is
# reused by the other holdout plots.
pd <- position_dodge(0.3)
# Stored under its own name so it is not silently overwritten by later plots.
p_cv <- ggplot(state_holdout_summary, aes(x = no_states, y = cv_rmse, colour = no_states)) +
  geom_errorbar(aes(ymin = cv_rmse - cv_rmse_ci, ymax = cv_rmse + cv_rmse_ci),
                width = .1, position = pd) +
  geom_line(position = pd) +
  geom_point(position = pd) +
  scale_x_continuous(breaks = 10:20) +
  # Zoom with the coordinate system: xlim()/ylim() would instead drop any
  # error bar whose ends fall outside the limits (e.g. the bars at 10 and 20).
  coord_cartesian(xlim = c(10, 20), ylim = c(0, 2))
  
# Training, cross-validated and held-out RMSE (summary value +/- its standard
# error) against the number of held-out states.
p <- ggplot(state_holdout_summary, aes(x = no_states)) +
  geom_line(aes(y = train_rmse, colour = "Training RMSE"), position = pd) +
  geom_errorbar(aes(ymin = train_rmse - train_rmse_se, ymax = train_rmse + train_rmse_se,
                    colour = "Training RMSE"), width = .1, position = pd) +
  geom_point(aes(y = train_rmse), position = pd, size = 2, shape = 21, fill = "white") +
  #
  geom_line(aes(y = cv_rmse, colour = "LOOCV RMSE"), position = pd) +
  geom_errorbar(aes(ymin = cv_rmse - cv_rmse_se, ymax = cv_rmse + cv_rmse_se,
                    colour = "LOOCV RMSE"), width = .1, position = pd) +
  geom_point(aes(y = cv_rmse), position = pd, size = 2, shape = 21, fill = "white") +
  #
  geom_line(aes(y = test_rmse, colour = "Test RMSE"), position = pd) +
  geom_errorbar(aes(ymin = test_rmse - test_rmse_se, ymax = test_rmse + test_rmse_se,
                    colour = "Test RMSE"), width = .1, position = pd) +
  geom_point(aes(y = test_rmse), position = pd, size = 2, shape = 21, fill = "white") +
  #
  expand_limits(y = 1.3) +
  # Only one x scale: ggplot2 replaces an earlier x scale (such as xlim())
  # with a later one, so the two were never both in effect.
  scale_x_continuous(breaks = 10:20) +
  xlab("Number of States Hold Out") + ylab("RMSE") +
  ggtitle("RMSE by Number of States Hold Out") + labs(colour = "Legend")
p

# In-sample R-squared (summary value +/- its standard error) against the
# number of held-out states. The legend entry now says R-sq rather than the
# previous, mistaken "Training RMSE".
q <- ggplot(state_holdout_summary, aes(x = no_states)) +
  geom_line(aes(y = train_r2, colour = "Training R-sq"), position = pd) +
  geom_errorbar(aes(ymin = train_r2 - train_r2_se, ymax = train_r2 + train_r2_se,
                    colour = "Training R-sq"), width = .1, position = pd) +
  geom_point(aes(y = train_r2), position = pd, size = 2, shape = 21, fill = "white") +
  expand_limits(y = 1) +
  # Only one x scale: an earlier xlim() would be replaced by this scale anyway.
  scale_x_continuous(breaks = 10:20) +
  xlab("Number of States Hold Out") + ylab("In-Sample R-sq") +
  ggtitle("In-Sample R-sq by Number of States Hold Out") + labs(colour = "Legend")
q






