import numpy as np
import matplotlib.pyplot as plt


# Orbital energies, lowest first (presumably Hartree, given the E_h axis
# label). Entry i is labelled by txt[i] and drawn as a horizontal level line.
E = np.array([-11, -10.997, -0.954, -0.705, -0.607, -0.349, 0.397, 0.553, 0.727, 1.475])

# Levels that receive an up/down electron pair: the five lowest sigma levels
# plus the Pi_u level E[5]. E[5] is listed twice because that level is drawn
# as two side-by-side lines (centres given in xfull). Derived from E so the
# two arrays cannot drift apart.
Efull = np.append(E[:6], E[5])
xfull = np.array([0, 0, 0, 0, 0, -0.7, 0.7])  # line centre for each entry of Efull

# Horizontal extents of the level lines: a centred line, plus the left/right
# pair used for the Pi levels (centred at -0.7 and +0.7, matching xfull).
x1 = np.array([-0.5, 0.5])
x2 = np.array([-1.2, -0.2])
x3 = np.array([0.2, 1.2])

# Spin arrows sit either side of each line centre. Despite their names,
# y1/y2 are x-coordinates: up arrows go at y1 and down arrows at y2.
offset = 0.15
y1 = xfull - offset
y2 = xfull + offset

# Term-symbol labels, one per entry of E.
txt = np.array([r'$1\Sigma_g^+$', r'$1\Sigma_u^+$', r'$2\Sigma_g^+$', r'$2\Sigma_u^+$', r'$3\Sigma_g^+$',
                r'$\Pi_u$', r'$\Pi_g$', r'$3\Sigma_u^+$', r'$4\Sigma_g^+$', r'$4\Sigma_u^+$'])


# Shared styling for the spin arrows; only the marker glyph differs.
_arrow_style = dict(color='deepskyblue', markersize=10, linestyle='None', zorder=10)
up = {**_arrow_style, 'marker': r'$\uparrow$'}
down = {**_arrow_style, 'marker': r'$\downarrow$'}

# Two stacked panels sharing the x-axis: a tall top panel and a short bottom
# panel (their y-limits are set later in this script).
fig, (ax1, ax2) = plt.subplots(
    2, 1,
    gridspec_kw={'height_ratios': [4, 1]},
    sharex=True,
    dpi=160,
    figsize=(5, 7),
)

# Every electron pair is drawn on both panels; each panel's y-limits decide
# which of the arrows end up visible.
for panel in (ax1, ax2):
    panel.plot(y1, Efull, **up)
    panel.plot(y2, Efull, **down)

a = 0.6     # label x-position for levels drawn on the centred line
b = 0.04    # downward nudge that lines a label up with its level (top panel)
c = 0.0003  # the same nudge, scaled for the bottom panel's much smaller y-range

# Top panel: (level index, label x, extra vertical tweak, font size), kept in
# the original drawing order.
# - The Pi labels sit at x=1.3, just right of the right-hand line of their pair.
# - The +0.03 / -0.035 tweaks push apart the labels of the close levels E[4] and E[3].
for idx, lx, tweak, size in (
    (9, a, 0, 14),
    (8, a, 0, 14),
    (7, a, 0, 14),
    (6, 1.3, 0, 15),
    (5, 1.3, 0, 15),
    (4, a, 0.03, 14),
    (3, a, -0.035, 14),
    (2, a, 0, 14),
):
    ax1.annotate(txt[idx], (lx, E[idx] - b + tweak), fontsize=size)

# Bottom panel: the two lowest levels.
for idx in (1, 0):
    ax2.annotate(txt[idx], (a, E[idx] - c), fontsize=14)



o = 'tab:orange'  # Pi levels
r = 'tab:red'     # Sigma_g levels
p = 'tab:pink'    # Sigma_u levels

# Top panel: (level index, x extents, colour) in the original drawing order.
# Each Pi level is drawn as a left/right pair of lines.
for lvl, extent, colour in (
    (9, x1, p),
    (8, x1, r),
    (7, x1, p),
    (6, x2, o),
    (6, x3, o),
    (5, x2, o),
    (5, x3, o),
    (4, x1, r),
    (3, x1, p),
    (2, x1, r),
):
    ax1.plot(extent, [E[lvl]] * 2, c=colour)

# Bottom panel: the two lowest levels.
for lvl, colour in ((1, p), (0, r)):
    ax2.plot(x1, [E[lvl]] * 2, c=colour)



# Zoom the bottom panel onto the two lowest levels. The window is centred on
# their midpoint (derived from E, not hard-coded) with half-height pm.
pm = 0.004
core_mid = (E[0] + E[1]) / 2
ax1.set_ylim(-1.1, 1.6)
ax2.set_ylim(core_mid - pm, core_mid + pm)
ax1.set_xlim(-3, 3)

# Hide the facing spines so the two panels read as one broken y-axis.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
# x carries no meaning here. The panels share x, so this clears ax2's ticks too.
ax1.set_xticks([])

# Diagonal "break" marks at the ends of the hidden spines, in axes coordinates.
d = .15  # proportion of vertical to horizontal extent of the slanted line
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
              linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

# Horizontal axis label, right-aligned and vertically centred so it stays
# clear of the tick labels.
ax1.set_ylabel(r'$E_h$', fontsize=12, rotation=0, ha='right', va='center')

fig.tight_layout()

fig.savefig('MO-diagram.png')



