#! /usr/bin/env python

#
# this python script is a pyomo-centric translation of the AMPL
# script found at: http://www.ampl.com/NEW/LOOP2/stoch2.run
#

# Python imports
from pyutilib.misc import import_file
from pyomo.environ import *
from pyomo.opt import SolverFactory
from pyomo.opt.base import SolverFactory
from pyomo.opt.parallel import SolverManagerFactory
from pyomo.opt.parallel.manager import solve_all_instances

# import the master model and
# initialize the master instance.
# load the python file that declares the master model, pull out its
# `model` object, and build the master instance from "master.dat".
master_module = import_file("master.py")
mstr_mdl = master_module.model
mstr_inst = mstr_mdl.create_instance("master.dat")

# initialize the sub-problem instances: one instance of the
# sub-problem model per (instance name, data file) pair, kept in
# the order base, low, high.
sb_mdl = import_file("subproblem.py").model
sub_insts = [
    sb_mdl.create_instance(name=inst_name, filename=data_file)
    for inst_name, data_file in (
        ("Base Sub-Problem", "base_subproblem.dat"),
        ("Low Sub-Problem", "low_subproblem.dat"),
        ("High Sub-Problem", "high_subproblem.dat"),
    )
]

# loop bookkeeping: the iteration cap for the benders loop and the
# smallest positive gap seen so far (starts at +infinity).
max_iterations = 50
gap = float("Inf")

# initialize the solver manager; the same "serial" manager is handed
# to every solve_all_instances call below.
solver_manager = SolverManagerFactory("serial")

# start the master's Min_Stage2_Profit at +infinity; its .value is
# read back when the gap is computed inside the loop.
mstr_inst.Min_Stage2_Profit = float("Inf")
# the main benders loop: solve the sub-problems, turn their pricing
# information into a new master cut, test for convergence, then
# re-solve the master and push its inventory levels back down.
for iteration in range(1, max_iterations + 1):

    print("\nIteration= %d" % (iteration))

    # solve the subproblems and report each one's stage-2 profit.
    solve_all_instances(solver_manager, 'cplex', sub_insts)
    for sub_inst in sub_insts:
        scenario_profit = round(sub_inst.Exp_Stage2_Profit(), 4)
        print("Profit for scenario=" + sub_inst.name + " is " +
              str(scenario_profit))
    print("")

    # register this iteration as a new cut index, then store the
    # pricing information from the sub-problem solutions in the
    # master.  scenarios are numbered from 1 in sub_insts order.
    mstr_inst.CUTS.add(iteration)
    for scen, sub_inst in enumerate(sub_insts, 1):
        duals = sub_inst.dual
        for week in mstr_inst.TWOPLUSWEEKS:
            mstr_inst.time_price[week, scen, iteration] = \
                duals[sub_inst.Time[week]]
        for prod in mstr_inst.PROD:
            mstr_inst.bal2_price[prod, scen, iteration] = \
                duals[sub_inst.Balance2[prod]]
            for week in mstr_inst.TWOPLUSWEEKS:
                # the `urc` entry of each Sell variable (presumably
                # its upper-bound reduced cost -- confirm against
                # the sub-problem model).
                mstr_inst.sell_lim_price[prod, week, scen, iteration] = \
                    sub_inst.urc[sub_inst.Sell[prod, week]]

    # add the master cut.  its right-hand side is accumulated from
    # three groups of terms, all taken at this iteration's prices:
    #   time_price * avail, bal2_price * (-Inv1), sell_lim_price * market
    cut = sum(
        mstr_inst.time_price[week, scen, iteration] * mstr_inst.avail[week]
        for week in mstr_inst.TWOPLUSWEEKS
        for scen in mstr_inst.SCEN
    )
    cut += sum(
        mstr_inst.bal2_price[prod, scen, iteration] * (-mstr_inst.Inv1[prod])
        for prod in mstr_inst.PROD
        for scen in mstr_inst.SCEN
    )
    cut += sum(
        mstr_inst.sell_lim_price[prod, week, scen, iteration]
        * mstr_inst.market[prod, week]
        for prod in mstr_inst.PROD
        for week in mstr_inst.TWOPLUSWEEKS
        for scen in mstr_inst.SCEN
    )
    mstr_inst.Cut_Defn.add(mstr_inst.Min_Stage2_Profit <= cut)

    # compute expected second-stage profit as the sum of the
    # sub-problems' Exp_Stage2_Profit values.
    Exp_Stage2_Profit = sum(
        sub_inst.Exp_Stage2_Profit() for sub_inst in sub_insts
    )
    print("Expected Stage2 Profit= " + str(round(Exp_Stage2_Profit, 4)))
    print("")

    # gap between the master's current Min_Stage2_Profit value and
    # what the sub-problems actually achieved, rounded to 6 places.
    stage2_estimate = mstr_inst.Min_Stage2_Profit.value
    newgap = round(stage2_estimate - Exp_Stage2_Profit, 6)
    if newgap == 0:
        # get rid of -0.0, which makes this script easier
        # to test against a baseline
        newgap = 0
    print("New gap= " + str(newgap) + "\n")

    # converged once the gap is no longer above the tolerance.
    if not (newgap > 0.00001):
        print("Benders converged!")
        break
    gap = min(gap, newgap)

    # re-solve the master and update the subproblem inv1 values.
    solve_all_instances(solver_manager, 'cplex', [mstr_inst])

    print("Master expected profit=" + str(mstr_inst.Expected_Profit()))

    # the master inventory values might be slightly less than 0
    # (within tolerance); threshold them at 0.0 once, then copy the
    # same levels into every sub-problem.
    inventory_levels = [
        (prod, max(mstr_inst.Inv1[prod](), 0.0))
        for prod in mstr_inst.PROD
    ]
    for sub_inst in sub_insts:
        for prod, level in inventory_levels:
            sub_inst.inv1[prod] = level
else:
    # only reached when every iteration ran without hitting `break`.
    print("Maximum Iterations Exceeded")

# report the master's Make1 / Sell1 / Inv1 values for each product,
# in sorted product order, rounded to 4 decimal places.
print("\nConverged master solution values:")
for prod in sorted(mstr_inst.PROD):
    made = round(mstr_inst.Make1[prod](), 4)
    print("Make1[" + prod + "]=" + str(made))
    sold = round(mstr_inst.Sell1[prod](), 4)
    print("Sell1[" + prod + "]=" + str(sold))
    stored = round(mstr_inst.Inv1[prod](), 4)
    print("Inv1[" + prod + "]=" + str(stored))
