///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
/// 
/// =========================================================================
#include "rheolef/error-estimator-zz.h"
#include "rheolef/form.h"
#include "rheolef/geo-connectivity.h"
#include "rheolef/ublas-invert.h"
#include "rheolef/piola_algo_v1.h"
#include "rheolef/riesz_representer.h"
#include "rheolef/exact_compose.h"
#include "rheolef/field-local-norm.h"
#include "rheolef/ublas-io.h"
using namespace std;
namespace rheolef { 
// --------------------------------------------------------------------------
// Part I. L2 projection post-processing
// --------------------------------------------------------------------------
// Rh = exact L2 projection
// returns Rh(grad(uh)) in Pk when grad_uh is in P(k-1)d, with k=1,2
// => for P2, this requires solving a linear system with the mass matrix
field postprocessing_prol2 (const field& grad_uh) {
    // target space: P1 for a P0 gradient, P2 otherwise
    string grad_approx = grad_uh.get_approx();
    const bool is_p0_gradient = (grad_approx == "P0");
    string approx = is_p0_gradient ? "P1" : "P2";
    const geo& omega = grad_uh.get_geo();
    const space& Lkm1h = grad_uh.get_space();
    space Hkh (omega, approx, "vector");
    // right-hand side operator: v in Hkh -> (grad_uh, v), built as a "mass" form
    form proj (Lkm1h, Hkh, "mass");
    field Rh_grad_uh (Hkh);
    if (is_p0_gradient) {
      // P1 target: diagonal "mass" form, inverted entrywise
      form_diag m (Hkh, "mass");
      Rh_grad_uh = (1.0/m)*(proj*grad_uh);
    } else {
      // P2 target: LDL^T factorization of the mass matrix, then solve
      form m (Hkh, Hkh, "mass");
      ssk<Float> fact_m = ldlt(m.uu);
      Rh_grad_uh.u = fact_m.solve((proj*grad_uh).u);
    }
    return Rh_grad_uh;
}
// Qh = inexact L2 projection, using quadrature formulae
// returns Qh(grad(uh)) in Pk when grad_uh is in P(k-1)d, with k=1,2;
// the projection may be exact with a sufficiently high quadrature order
// NOTE: choosing the quadrature nodes at the superconvergent points
//  does not lead to a good L2 projection with P2,
//  so this method is not useful
// Inexact L2 projection of grad_uh onto the continuous vector space Pk,
// where the right-hand side is computed with the quadrature options qopt.
//  grad_uh : P0 (-> P1 target) or any other approx (-> P2 target)
//  qopt    : quadrature used by riesz_representer for the right-hand side
field postprocessing_prol2 (const field& grad_uh, quadrature_option qopt) {
    const geo& omega = grad_uh.get_geo();
    string grad_approx = grad_uh.get_approx();
    string approx = (grad_approx == "P0") ? "P1" : "P2"; 
    space Hkh   (omega, approx, "vector"); 
    // right-hand side: v in Hkh -> (grad_uh, v), evaluated with qopt
    field proj_grad_uh = riesz_representer (Hkh, grad_uh, qopt);
    field Qh_grad_uh (Hkh);
    if (approx != "P1") {
      form m    (Hkh, Hkh, "mass");
      bool wild_lumping = false; // fails on triangles ! seems ok in 1d and 3d...
      if (wild_lumping) {
        // row-sum lumping of the consistent mass matrix (disabled, see above)
        field one (Hkh, 1.0);
        form_diag inv_md = 1.0/form_diag(m*one);
        Qh_grad_uh = inv_md*(proj_grad_uh);
      } else {
        // LDL^T factorization of the mass matrix, then solve
        ssk<Float> fact_m = ldlt(m.uu);
        Qh_grad_uh.u = fact_m.solve(proj_grad_uh.u);
      }
    } else {
      form_diag m (Hkh, "mass");
      Qh_grad_uh = (1.0/m)*(proj_grad_uh);
    }
    return Qh_grad_uh;    
}
// --------------------------------------------------------------------------
// Part II. Zienkiewicz-Zhu projection-like post-processing
// --------------------------------------------------------------------------
// when a vertex is on the boundary, the macro_element
// around this vertex is too small and the least-squares problem is then singular.
// => extend it by adding every element that shares a vertex with an element of M
void
extend_macro_element (
  const geo& omega,
  const vector<set<geo::size_type> >& macro_element,
  set<geo::size_type>& M
)
{
  set<geo::size_type> old_M;
  copy (M.begin(), M.end(), inserter(old_M, old_M.end()));
  for (set<geo::size_type>::const_iterator p = old_M.begin(); p != old_M.end(); p++) {
    const geo_element& K = omega.element(*p);
    for (size_t iloc = 0; iloc < K.size(); iloc++)  {
      size_t i = K[iloc];
      // M := M union M[i]
      in_place_set_union (M, macro_element[i]);
    }
  }
}
// ---------------------------------------------
// compute the bounding box of the macro-element
// ---------------------------------------------
// Computes the axis-aligned bounding box of the vertices of the elements
// in M and stores it as a corner simplex:
//   bbox[0]   = xmin
//   bbox[l+1] = xmin with its l-th coordinate replaced by xmax[l], 0 <= l < d
// Only bbox[0..d] are written; coordinates of index >= d are zero.
void
macro_element_bbox (
  const geo& omega,
  const set<geo::size_type>& M,
  point bbox[4])
{
  size_t d = omega.dimension();
  point xmin, xmax;
  for (geo::size_type j = 0; j < d; j++) {
    xmin[j] =  numeric_limits<Float>::max();
    xmax[j] = -numeric_limits<Float>::max();
  }
  // zero every unused coordinate, starting at index d
  // (starting at d+1 would leave coordinate d unset)
  for (geo::size_type j = d; j < 3; j++) {
    xmin [j] = xmax [j] = 0;
  }
  for (set<geo::size_type>::const_iterator p = M.begin(); p != M.end(); p++) {
    const geo_element& K = omega.element(*p);
    for (size_t iloc = 0; iloc < K.size(); iloc++)  {
      size_t i = K[iloc];
      const point& xi = omega.vertex(i);
      for (geo::size_type j = 0 ; j < d; j++) {
        xmin[j] = ::min(xi[j], xmin[j]);
        xmax[j] = ::max(xi[j], xmax[j]);
      }
    }
  }
  bbox[0] = xmin;
  for (size_t l = 0; l < d; l++) {
    bbox[l+1]    = xmin;
    bbox[l+1][l] = xmax[l];
  }
}
// ----------------------------------
// macro_element piola transformation
// M is a surrounding simplex
// based on the bounding box
// -> nondimensionalize coords in the
//    reference element tilde_M
//    to avoid floating-point problems
//    during polynomial evaluation
// ----------------------------------
// Maps hat_x from the reference macro-element tilde_M to physical
// coordinates, using the vertices p[] of the surrounding simplex:
// p[0..1] for an edge, p[0..2] for a triangle, p[0..3] otherwise.
point
macro_element_piola (
  const reference_element& tilde_M,
  point p[4],
  const point& hat_x)
{
  if (tilde_M.variant() == reference_element::e) {
    return piola_e (hat_x, p[0], p[1]);
  }
  if (tilde_M.variant() == reference_element::t) {
    return piola_t (hat_x, p[0], p[1], p[2]);
  }
  // every other variant is handled as a tetrahedron
  return piola_T (hat_x, p[0], p[1], p[2], p[3]);
}
// Inverse of macro_element_piola: maps the physical point x to the
// coordinates of the reference macro-element tilde_M, using the vertices
// p[] of the surrounding simplex: p[0..1] for an edge, p[0..2] for a
// triangle, p[0..3] otherwise.
point
macro_element_inv_piola (
  const reference_element& tilde_M,
  point p[4],
  const point& x)
{
  if (tilde_M.variant() == reference_element::e) {
    return inv_piola_e (x, p[0], p[1]);
  }
  if (tilde_M.variant() == reference_element::t) {
    return inv_piola_t (x, p[0], p[1], p[2]);
  }
  // every other variant is handled as a tetrahedron
  return inv_piola_T (x, p[0], p[1], p[2], p[3]);
}
// ---------------------------------------------------------
// compute g(x) = sum_j g_j*phi_j(x) when x in macro-element
// ---------------------------------------------------------
// Evaluates the vector polynomial fitted on the macro-element at x.
//  - regular case: g(x) = sum_j g_j*phi_j(tilde_x), where tilde_x is x mapped
//    to the reference macro-element through the bounding simplex bbox;
//    phi_x is a caller-provided working array of size nrow.
//  - singular case: the fit is unusable, so grad_uh itself is evaluated.
point
macro_element_vector_evaluate (
  const point& x,
  const reference_element& tilde_M,
  point bbox[4],
  const set<geo::size_type>& M,
  const ublas::vector<point>& g,
  const basis& phi,
  bool is_singular,
  const field& grad_uh,
  std::vector<Float>& phi_x) // working array, size=nrow
{
  if (!is_singular) {
    const point tilde_x = macro_element_inv_piola (tilde_M, bbox, x);
    phi.eval (tilde_M, tilde_x, phi_x);
    point sum (0.0,0.0,0.0);
    const size_t n_coef = g.size();
    for (size_t j = 0; j < n_coef; ++j) {
      sum = sum + phi_x[j]*g[j];
    }
    return sum;
  }
  // singular least-squares problem (patch too small ?)
  // -> interpolate grad_uh at x instead
#define HAVE_BUG_IN_LOCALIZE
#ifdef HAVE_BUG_IN_LOCALIZE
  // take the first element of M, whether or not it contains x
  const geo& omega = grad_uh.get_geo();
  const size_t first_K_idx = *(M.begin());
  const geo_element& first_K = omega.element(first_K_idx);
  meshpoint hat_x = grad_uh.get_space().hatter (x, first_K.index());
  point value;
  grad_uh.evaluate (hat_x, value);
  return value;
#else // ! HAVE_BUG_IN_LOCALIZE
  // localize is buggy in 2d on small mesh 
  // and not available in 3d
  return grad_uh.vector_evaluate (x);
#endif // ! HAVE_BUG_IN_LOCALIZE
}
// ----------------------------------
// Rh = local SPR ZZ projection
// return Rh(grad(uh)) in Pk when grad_uh in P(k-1)d and k=1,2
// ----------------------------------
// Local patch recovery of grad(uh).
//  grad_uh : P0 (k=1) or P1d (k=2) vector field
//  returns : rh, a continuous Pk vector field
// For each vertex xi:
//  1. Mi = elements around xi. While the number ncol of quadrature
//     (sampling) points of Mi is too small compared with the number nrow of
//     Pk basis functions, Mi is enlarged (at most 3 times).
//  2-3. Assemble the quadrature-weighted normal equations A*g = b of the
//     least-squares fit of grad_uh at the superconvergent quadrature nodes
//     of Mi. The fit uses the coordinates of the bounding simplex of Mi.
//  4. Solve for g. When A cannot be inverted or ncol < nrow, the vertex
//     is flagged singular and macro_element_vector_evaluate falls back to
//     evaluating grad_uh directly.
//  5. Set the vertex dof of rh. When k=2, also accumulate the fitted values
//     at the midpoints of the edges around xi, weighted by meas(Mi); these
//     edge dofs are averaged at the end.
field postprocessing_macro (const field& grad_uh) {
warning_macro ("post macro");
  const geo& omega = grad_uh.get_geo();
  size_t d = omega.dimension();
  string grad_approx = grad_uh.get_approx();
  string approx = (grad_approx == "P0") ? "P1" : "P2"; 
  space Hkh_scalar (omega, approx);
  space Hkh        (omega, approx, "vector");
  size_t k = Hkh_scalar.degree();
  field wh (Hkh_scalar, 0.0); // per-dof averaging weights
  field rh (Hkh, 0.0);        // recovered gradient
  quadrature_option qopt;
  qopt.set_family (quadrature_option::superconvergent);
  qopt.set_order  (Hkh.degree() );
  quadrature<Float> quad (qopt);
  basis phi = Hkh_scalar.get_basis();
  piola_on_quadrature piola (quad, Hkh_scalar);
  // advanced connectivity: vertex -> elements, and vertex -> edges when needed
  vector<set<geo::size_type> > macro_element (omega.n_vertex());
  build_point_to_element_sets (omega.begin(), omega.end(), macro_element.begin());
  vector<geo_element> edge;
  vector<set<geo::size_type> > macro_edge;
  if (k >= 2 && d >= 2) { // when d=1 : can use element list !
    edge.resize (omega.n_edge());
    build_edge (omega, edge);
    macro_edge.resize (omega.n_vertex());
    build_point_to_element_sets (edge.begin(), edge.end(), macro_edge.begin());
  }
  reference_element tilde_M;
  switch (d) {
   case 1:  tilde_M.set_variant (reference_element::e); break;
   case 2:  tilde_M.set_variant (reference_element::t); break;
   default: tilde_M.set_variant (reference_element::T); break;
  }
  size_t nrow = phi.size (tilde_M);
  std::vector<Float> phi_xq (nrow); // working arrays: can be the same one
  std::vector<Float> phi_xi (nrow);
  std::vector<Float> phi_xe (nrow);
  size_t n_singular = 0;
  for (size_t i = 0; i < omega.n_vertex(); i++) {
    const point& xi = omega.vertex(i);
    // step 1: count the sampling points of the patch around xi
    set<geo::size_type> Mi;
    copy (macro_element[i].begin(), macro_element[i].end(), inserter(Mi,Mi.end()));
    size_t ncol = 0;
    for (set<geo::size_type>::const_iterator p = Mi.begin(); p != Mi.end(); p++) {
      const geo_element& K = omega.element(*p);
      ncol += quad.size (K);
    }
    const size_t n_extend_max = 3;
    for (size_t n_extend = 0; n_extend < n_extend_max &&
	((ncol <= nrow && d > 1) || (ncol < nrow && d == 1)); n_extend++) {
#ifdef TO_CLEAN
      size_t old_ncol = ncol;
#endif // TO_CLEAN
      // extend Mi and re-count the number of sampling points ncol
      extend_macro_element (omega, macro_element, Mi);
      ncol = 0;
      for (set<geo::size_type>::const_iterator p = Mi.begin(); p != Mi.end(); p++) {
        const geo_element& K = omega.element(*p);
        ncol += quad.size (K);
      }
#ifdef TO_CLEAN
warning_macro ("old ncol="<<old_ncol<<", new ncol="<<ncol<<" and nrow="<<nrow);
#endif // TO_CLEAN
    }
    // step 1.b: compute bounding box of Mi
    point bbox[4];
    macro_element_bbox (omega, Mi, bbox);

    // step 2: compute B (basis values at sampling points), w (weights) and f (samples)
    ublas::matrix<Float> B (nrow,ncol);
    ublas::vector<Float> w (ncol);
    ublas::vector<point> f (ncol);
    size_t icol = 0;
    Float meas_Mi = 0;
    for (set<geo::size_type>::const_iterator p = Mi.begin(); p != Mi.end(); p++) {
      size_t K_idx = *p;
      const geo_element& K = omega.element(K_idx);
      meas_Mi += omega.measure(K);
      reference_element hat_K (K);
      quadrature<Float>::const_iterator first_quad = quad.begin(hat_K);
      quadrature<Float>::const_iterator last_quad  = quad.end  (hat_K);
      for (size_t q = 0; first_quad != last_quad; first_quad++, q++, icol++) {
        meshpoint hat_xq (K_idx, (*first_quad).x);
        Float     hat_wq = (*first_quad).w;
        point xq = Hkh.dehatter (hat_xq);
        Float wq = hat_wq*piola.det_jacobian_transformation (K,q);
        w[icol] = wq;
        grad_uh.evaluate (hat_xq, f[icol]);
        point tilde_xq = macro_element_inv_piola (tilde_M, bbox, xq);
        phi.eval (tilde_M, tilde_xq, phi_xq);
        for (size_t irow = 0; irow < nrow; irow++) {
          B(irow,icol) = phi_xq [irow];
        }
      }
    }
#ifdef TO_CLEAN
    cerr << "w" << " = " << w << endl;
    cerr << "B" << " = " << B << endl;
#endif // TO_CLEAN
    // step 3: compute A=B*W*B^T and b=B*W*f, with W=diag(w)
    ublas::matrix<Float> A (nrow,nrow);
    ublas::vector<point> b (nrow);
    for (size_t irow = 0; irow < nrow; irow++) {
      b[irow] = point(0.0);
      for (size_t kcol = 0; kcol < ncol; kcol++) {
        b[irow] = b[irow] + B(irow,kcol)*f[kcol]*w[kcol];
      }
      for (size_t jrow = 0; jrow < nrow; jrow++) {
        A(irow,jrow) = 0;
        for (size_t kcol = 0; kcol < ncol; kcol++) {
          A(irow,jrow) += B(irow,kcol)*B(jrow,kcol)*w[kcol];
        }
      }
    }
    // step 4: solv A*g = b
    // TODO: factorize LDLt is more efficient
    ublas::matrix<Float> inv_A (nrow, nrow);
    bool is_singular = ! invert (A, inv_A);
    // TODO: det(A) can be obtained also with the factorization
#ifdef TO_CLEAN
    Float det_A = determinant(A);
    cerr << "A" << " = " << A << endl;
    cerr << "inv_A" << " = " << inv_A << endl;
    cerr << "det_A" << " = " << det_A << endl;
    const Float eps_mach = std::numeric_limits<Float>::epsilon();
    // when P2, then det_A becomes very small => adimensionalize K_macro
    if (fabs(det_A) < eps_mach) is_singular = true;
#endif // TO_CLEAN
    // under-determined fit: not enough sampling points
    if (ncol < nrow) is_singular = true;
    ublas::vector<point> g (nrow);
    if (!is_singular) {
      // g = inv_A*b;
      for (size_t irow = 0; irow < nrow; irow++) {
        g[irow] = point(0.0);
        for (size_t jrow = 0; jrow < nrow; jrow++) {
          g[irow] = g[irow] + inv_A(irow,jrow)*b[jrow];
        }
      }
    } else {
      n_singular++;
      warning_macro ("singular least-square macro-element matrix on " << i << "-th vertex ; macro-element size = "<<Mi.size()<<", ncol="<<ncol<<", nrow="<<nrow);
    }
    // step 5: set dof
    // step 5.a: set dof on vertices
    //           rh(xi) = sum_j g(j)*phi_j(xi)
    wh.at(i) = 1.0;
    point rh_xi = macro_element_vector_evaluate (xi, tilde_M, bbox, Mi, g, phi, is_singular, grad_uh, phi_xi);
    rh.set_vector_at (i, rh_xi);
    // step 5.b: accumulate dof on edges (element=e,t,T), only when k=2
    if (k == 1) {
      continue;
    }
    // edges are elements when d=1
    const set<geo::size_type>& Ei = ((d==1) ? (macro_element[i]) : (macro_edge[i]));
    for (set<geo::size_type>::const_iterator q = Ei.begin(); q != Ei.end(); q++) {
      size_t E_idx = *q;
      const geo_element& E = ((d==1) ? (omega.element(E_idx)) : (edge[E_idx]));
      // edge endpoints (named so as not to shadow the right-hand side b)
      const point& xa = omega.vertex(E[0]);
      const point& xb = omega.vertex(E[1]);
      point xe = 0.5*(xa+xb);
      size_t E_dof = omega.n_vertex() + E_idx; // P2
      wh.at(E_dof) += meas_Mi;
      point rh_xe = macro_element_vector_evaluate (xe, tilde_M, bbox, Mi, g, phi, is_singular, grad_uh, phi_xe);
      rh.incr_vector_at (E_dof, meas_Mi*rh_xe);
    }
  }
  // weighted average values at edges, for P2
  if (k >= 2) {
    if (d == 1) {
        rh = rh/wh;
    } else {
      for (size_t l = 0; l < d; l++) {
        rh[l] = rh[l]/wh;
      }
    }
  }
warning_macro ("post macro: " << n_singular << "/" << omega.n_vertex() << " singular vertex");
warning_macro ("post macro done");
  return rh;
}
// --------------------------------------------------------------------------
// Part III. Main interface and error estimator
// --------------------------------------------------------------------------
// Gradient recovery dispatcher: "prol2" selects the L2 projection
// (postprocessing_prol2); any other method selects the local patch
// recovery (postprocessing_macro).
field
postprocessing (const field& grad_uh, string method)
{
  const bool use_l2_projection = (method == "prol2");
  if (!use_l2_projection) {
    return postprocessing_macro (grad_uh);
  }
  return postprocessing_prol2 (grad_uh);
}
// Vector error indicator eta_h = (I-Rh) grad(uh), expressed in Pkd:
// P1d when grad_uh is P0, P2d when grad_uh is P1d.
//  grad_uh : gradient of the discrete solution (P0 or P1d vector field)
//  method  : recovery method forwarded to postprocessing()
// Fails (check_macro) on any other gradient approximation.
field
estim_vector (const field& grad_uh, string method) {
    // validate the input before running the (costly) recovery step
    string grad_approx = grad_uh.get_approx();
    check_macro (grad_approx == "P0" || grad_approx == "P1d",
	"unexpected gradient approximation `" << grad_approx << "'");
    const geo& omega = grad_uh.get_geo();
    field Rh_grad_uh = postprocessing (grad_uh, method);
    string up_grad_approx = (grad_approx == "P0") ? "P1d" : "P2d";
    const space& Lkm1h_vec = grad_uh.get_space();
    const space& Hkh_vec = Rh_grad_uh.get_space();
    // ------------------------------------
    // eta_h as Pkd
    // ------------------------------------
    // eta_h = (I-Rh) grad(uh) is Pkd
    //  since grad(uh) is P(k-1)d and Rh(grad(uh)) is Pk-C0
    space Lkh_vec (omega, up_grad_approx, "vector");
    form proj_0 (Lkm1h_vec, Lkh_vec, "mass");
    form proj_1 (Hkh_vec,   Lkh_vec, "mass");
    form inv_m (Lkh_vec, Lkh_vec, "inv_mass");
    // express both gradients in Pkd, then take their difference
    field    grad_uh_Pkd = inv_m*(proj_0*grad_uh);
    field Rh_grad_uh_Pkd = inv_m*(proj_1*Rh_grad_uh);
    field eta_h = grad_uh_Pkd - Rh_grad_uh_Pkd;
    return eta_h;
}
// the squared norm of eta_h gives a scalar estimate of the squared local error:
//   eta_h_K^2 = int_K |eta_h(x)|^2 dx 
// Scalar error indicator: the local squared L2 norm (norm2_L2_local) of
// the vector indicator estim_vector(grad_uh, method).
field estim2 (const field& grad_uh, string method) {
    return norm2_L2_local (estim_vector (grad_uh, method));
}
}// namespace rheolef
