""" octavePlots module
"""

from math import ceil, log
from geniconfig import Statistics

def generatePlot(output, xdata, series, title, xlabel, ylabel, plotOutput = None, xrange = None, maxY = None, xlogscale=False, ylogscale = False, keyBoxed = True, keyPos = 'right top'):
    """Write an Octave script that plots every series against shared X data.

    Arguments:
    output     -- file-like object with a ``write`` method; receives the script
    xdata      -- X values; emitted as ``X = <str(xdata)>;``
    series     -- sequence of ``(seriesName, seriesData)`` pairs; every item of
                  ``seriesData`` is formatted with ``%s`` as one entry of the
                  ``data_<i>`` matrix
    title, xlabel, ylabel -- plot and axis captions
    plotOutput -- passed to ``_writePlotHeader`` (output file base name)
    xrange     -- optional ``(min, max)`` pair for the X axis
    maxY       -- optional upper bound for the Y axis
    xlogscale, ylogscale -- enable logarithmic axes
    keyBoxed, keyPos     -- legend box flag and legend position

    ``output`` is left open when the function returns.
    """
    for (i, (seriesName, seriesData)) in enumerate(series):
        output.write('%% %s\n' % seriesName)
        header = 'data_%i = [ ' % i
        # Continuation entries are indented so they line up under the first one.
        rowSeparator = ',\n' + (' ' * len(header))
        output.write(header)
        output.write(rowSeparator.join(["%s" % value for value in seriesData]))
        output.write('\n];\n\n')

    # xrange used to be accepted but silently dropped here; forward it so the
    # requested X axis range actually reaches the plot header.
    _writePlotHeader(output     = output,
                     title      = title,
                     xlabel     = xlabel,
                     ylabel     = ylabel,
                     plotOutput = plotOutput,
                     xrange     = xrange,
                     maxY       = maxY,
                     xlogscale  = xlogscale,
                     ylogscale  = ylogscale,
                     keyBoxed   = keyBoxed,
                     keyPos     = keyPos)

    output.write('X = %s;\n' % str(xdata))

    # Column 1 of plot_data is X; column i+2 holds series i.
    dataColumns = "".join([", data_%i" % i for i in range(len(series))])
    output.write("plot_data = [X' " + dataColumns + "];\n\n")

    plotClauses = ["plot_data using 1:%i title '%s'" % (i + 2, seriesName)
                   for (i, (seriesName, _)) in enumerate(series)]
    output.write('__gnuplot_plot__ ' + ', '.join(plotClauses) + ';\n')
    output.write("closeplot;\n")


def generateMedianPlot(output, xdata, series, title, xlabel, ylabel, plotOutput = None, xrange = None, xlogscale=False, ylogscale = False, keyBoxed = True, keyPos = 'right top'):
    """Write an Octave script plotting the median of every data point.

    Arguments:
    series -- sequence of ``(seriesName, seriesData)`` pairs where each item
              of ``seriesData`` is itself a collection of samples; the value
              plotted for that item is the Octave expression ``median(<item>)``

    All other arguments are passed through to ``generatePlot``.

    The Y axis upper bound handed to ``generatePlot`` is the largest sample
    across all series, ignoring samples equal to ``Statistics.INFINITY``; it
    is never lower than 0.
    """
    maxY = 0
    seriesWithMedian = []
    for (seriesName, seriesData) in series:
        # Build real lists (not lazy map/filter objects) so the code behaves
        # the same on Python 2 and Python 3.
        medianExprs = ['median(%s)' % vec for vec in seriesData]
        seriesWithMedian.append((seriesName, medianExprs))
        for values in seriesData:
            finiteValues = [x for x in values if x != Statistics.INFINITY]
            # maxY starts at 0, so skipping an empty list is equivalent to
            # the old trick of appending a 0 before taking the maximum.
            if finiteValues:
                maxY = max(maxY, max(finiteValues))

    # xlogscale used to be accepted but never forwarded, so it had no effect.
    generatePlot(output     = output,
                 xdata      = xdata,
                 series     = seriesWithMedian,
                 title      = title,
                 xlabel     = xlabel,
                 ylabel     = ylabel,
                 plotOutput = plotOutput,
                 xrange     = xrange,
                 maxY       = maxY,
                 xlogscale  = xlogscale,
                 ylogscale  = ylogscale,
                 keyBoxed   = keyBoxed,
                 keyPos     = keyPos)


def generateBoxPlot(output, series, title, xlabel, ylabel, plotOutput = None, keyBoxed = True, keyPos = 'right top', pointSep = 3):
    """Write an Octave script drawing the series as grouped boxes.

    Arguments:
    output     -- file-like object receiving the script; it is CLOSED by this
                  function once the script has been written
    series     -- non-empty sequence of ``(seriesName, seriesData)`` pairs;
                  every ``seriesData`` must have the same length
    title, xlabel, ylabel -- plot and axis captions
    plotOutput -- passed to ``_writePlotHeader`` (output file base name)
    keyBoxed, keyPos -- legend box flag and legend position
    pointSep   -- horizontal gap, in box widths, between groups of boxes

    Raises ValueError when ``series`` is empty or the series have different
    lengths; nothing is written to ``output`` in that case.
    """
    numberOfSeries = len(series)
    if numberOfSeries == 0:
        raise ValueError('generateBoxPlot needs at least one series')

    # Validate up front so a bad input never leaves a half-written script.
    lengths = [len(seriesData) for (_, seriesData) in series]
    numberOfPoints = lengths[0]
    for length in lengths:
        if length != numberOfPoints:
            raise ValueError('all series must have the same number of points')

    # Each group occupies one slot per series plus the separating gap.
    pointWidth = pointSep + numberOfSeries

    for (i, (seriesName, seriesData)) in enumerate(series):
        output.write('%% %s\n' % seriesName)
        output.write('data_%i = %s;\n\n' % (i, seriesData))

    for i in range(numberOfSeries):
        # list() keeps the "[a, b, c]" text form on Python 3, where a bare
        # range object would be rendered as "range(a, b, c)".
        positions = list(range(pointSep + i, numberOfPoints * pointWidth, pointWidth))
        output.write('X_%i = %s;\n' % (i, positions))

    # One tic per group, labelled 1..N and placed at the middle of the group.
    xtics = ', '.join(['"%s" %i' % (i + 1, pointWidth * i + ceil(pointWidth / 2.0))
                       for i in range(numberOfPoints)])

    _writePlotHeader(output     = output,
                     plotOutput = plotOutput,
                     title      = title,
                     xlabel     = xlabel,
                     xrange     = (0, numberOfPoints * pointWidth + 1),
                     xtics      = 'nomirror (%s)' % xtics,
                     ylabel     = ylabel,
                     ylogscale  = False,
                     keyBoxed   = keyBoxed,
                     keyPos     = keyPos)
    output.write('__gnuplot_set__ style fill solid border -1;\n')
    output.write('__gnuplot_set__ boxwidth 1\n')
    output.write('__gnuplot_set__ tics out\n\n')

    # plot_data interleaves columns: X_0, data_0, X_1, data_1, ...
    columnPairs = ["X_%i', data_%i'" % (i, i) for i in range(numberOfSeries)]
    output.write("plot_data = [" + ', '.join(columnPairs) + "];\n\n")

    plotClauses = ["plot_data using %i:%i title '%s' with boxes" % (2*i+1, 2*i+2, seriesName)
                   for (i, (seriesName, _)) in enumerate(series)]
    output.write('__gnuplot_plot__ ' + ', '.join(plotClauses) + ';\n')
    output.write("closeplot;\n")
    output.close()


def _writePlotHeader(output, title, xlabel, ylabel, plotOutput = None, xrange = None, maxY = None, xlogscale = False, ylogscale = False, xtics = -1, keyBoxed = True, keyPos = 'right top'):
    if plotOutput:
        output.write("__gnuplot_set__ term postscript eps color;\n")
        output.write("__gnuplot_set__ output '%s.%s';\n" % (plotOutput, defaultPlotExt))

    output.write("\n")
    output.write("title('%s');\n" % title)
    output.write("xlabel('%s');\n" % xlabel)
    output.write("ylabel('%s');\n" % ylabel)
    if xrange != None:
        output.write("__gnuplot_set__ xrange [%i:%i];\n" % xrange)
    if maxY != None:
        if ylogscale:
            output.write("__gnuplot_set__ yrange [*:%i];\n" % 10**(ceil(log(maxY, 10))))
        else:
            output.write("__gnuplot_set__ yrange [*:%f];\n" % maxY)

    if xtics != -1:
      output.write("__gnuplot_set__ xtics %s;\n" % xtics)
    output.write("__gnuplot_set__ nologscale;\n")
    output.write('__gnuplot_set__ style data linespoints;\n')
    if ylogscale:
        output.write("__gnuplot_set__ logscale y;\n")
    if xlogscale:
        output.write("__gnuplot_set__ logscale x;\n")
    if keyBoxed:
        output.write("__gnuplot_set__ key %s box;\n" % keyPos)
    else:
        output.write("__gnuplot_set__ key %s;\n" % keyPos)
    output.write("\n")

# File extension appended to plotOutput when a plot is written to a file.
defaultPlotExt = "eps"
