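''' This program produces Figure 18.6: the fitted plane y = b0 + b1*x1 + b2*x2 for the vending-machine data, with a vertical residual segment from each observation to the plane (red star: observation on/above the plane, blue star: below). '''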

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
#from mpl_toolkits.mplot3d import Axes3D
from python18_2 import fct
from python18_examples import vending_machine

mpl.rcParams['xtick.labelsize'] = 8
mpl.rcParams['lines.linewidth'] = 0.8
mpl.rcParams['lines.markersize'] = 4

# Axis limits of the 3-D plot. The fitted plane is outlined over exactly
# this (x1, x2) rectangle, so keep the limits and the outline in one place.
X1_MAX, X2_MAX, Y_MAX = 30, 1500, 80

# Load the data and fit the linear model y ~ beta0 + beta1*x1 + beta2*x2.
# x1 and x2 are concatenated column-wise, so they are presumably (n, 1)
# column vectors -- TODO confirm against vending_machine().
x1, x2, y = vending_machine()
Xin = np.concatenate((x1, x2), axis=1)
[X, beta, yhat, ybar, Syy, SSE, SSR, R2] = fct(Xin, y)

# Flatten the coefficients so b0, b1, b2 are scalars regardless of whether
# fct returns beta as a column vector or a 1-D array.
b0, b1, b2 = np.asarray(beta, dtype=float).ravel()[:3]

# Height of the fitted plane directly above/below each observation.
z = b0 + b1 * x1 + b2 * x2

fig = plt.figure()
# Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot is
# the supported way to create a 3-D axes.
ax = fig.add_subplot(projection='3d')
ax.set_xlim(0, X1_MAX)
ax.set_ylim(0, X2_MAX)
ax.set_zlim(0, Y_MAX)
ax.view_init(elev=20, azim=220)

# One residual segment plus marker per observation. Iterate over every
# point (not a hard-coded count) so data sets of any size are fully drawn.
for x1_i, x2_i, y_i, z_i in zip(np.ravel(x1), np.ravel(x2),
                                np.ravel(y), np.ravel(z)):
    # Vertical segment from the observation to the fitted plane.
    ax.plot((x1_i, x1_i), (x2_i, x2_i), (y_i, z_i), 'k')
    # Red star if the observation lies on/above the plane, blue if below.
    marker = '*r' if y_i >= z_i else '*b'
    ax.plot((x1_i, x1_i), (x2_i, x2_i), (y_i, y_i), marker)

# Outline the fitted plane over the plotted (x1, x2) rectangle: walk the
# four corners and return to the start to close the loop.
a = np.array([0, X1_MAX, X1_MAX, 0, 0])
b = np.array([0, 0, X2_MAX, X2_MAX, 0])
c = b0 + b1 * a + b2 * b
ax.plot(a, b, c, 'k')

# Uncomment to display the figure interactively instead of only saving it.
#plt.show()

fig.savefig('figure_18.6.pdf', dpi=600)
plt.close(fig)