#!/usr/bin/perl
#
#########################################################################
#                                                                       #
# Open Dynamics Engine, Copyright (C) 2001,2002 Russell L. Smith.       #
# All rights reserved.  Email: russ@q12.org   Web: www.q12.org          #
#                                                                       #
# This library is free software; you can redistribute it and/or         #
# modify it under the terms of EITHER:                                  #
#   (1) The GNU Lesser General Public License as published by the Free  #
#       Software Foundation; either version 2.1 of the License, or (at  #
#       your option) any later version. The text of the GNU Lesser      #
#       General Public License is included with this library in the     #
#       file LICENSE.TXT.                                               #
#   (2) The BSD-style license that is included with this library in     #
#       the file LICENSE-BSD.TXT.                                       #
#                                                                       #
# This library is distributed in the hope that it will be useful,       #
# but WITHOUT ANY WARRANTY; without even the implied warranty of        #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the files    #
# LICENSE.TXT and LICENSE-BSD.TXT for more details.                     #
#                                                                       #
#########################################################################

#
# triangular matrix solver and factorizer code generator.
#
# SOLVER
# ------
#
# if L is an n x n lower triangular matrix (with ones on the diagonal), the
# solver solves L*X=B where X and B are n x m matrices. this is the core
# step in L*D*L' factorization. the algorithm is (in matlab):
#
#     for i=1:n
#       for j=1:m
#         X(i,j) = B(i,j) - L(i,1:i-1)*X(1:i-1,j);
#       end
#     end
#
# note that the ordering of the (i,j) loop is somewhat arbitrary. the only
# prerequisite to calculating element (i,j) of X is that all X(1:i-1,j) have
# already been calculated. this gives us some flexibility.
#
# the code generated below calculates X in N1 x N1 blocks. to speed up the
# innermost dot product loop, the outer product trick is used. for instance,
# to calculate the value of the 2x2 matrix ABCD below we first iterate over
# the vectors (a,b,c,d) and (e,f,g,h), computing ABCD = a*e+b*f+c*g+d*h.
# then A and B contain the dot product values needed in the algorithm, and
# C and D have almost all of it. the outer product trick reduces the number
# of memory loads required. in this example 16 loads are required, but if
# the simple dot product in the above algorithm is used then 32 loads are
# required. increasing N1 decreases the total number of loads, but only as long
# as we have enough temporary registers to keep the matrix blocks and vectors.
#
#                L                     *    X     =     B
#
#  [ .                               ]   [ e e ]     [ . . ]
#  [ . .                             ]   [ f f ]     [ . . ]
#  [ . . .                           ]   [ g g ]     [ . . ]
#  [ . . . .                         ]   [ h h ]     [ . . ]
#  [ a b c d .                       ]   [ A B ]  =  [ . . ]
#  [ a b c d . .                     ]   [ C D ]     [ . . ]
#  [ . . . . . . .                   ]   [ . . ]     [ . . ]
#  [ . . . . . . . .                 ]   [ . . ]     [ . . ]
#  [ . . . . . . . . .               ]   [ . . ]     [ . . ]
#
# note that L is stored by rows but X and B are stored by columns.
# the outer product loops are unrolled for extra speed.
#
# LDLT FACTORIZATION
# ------------------
#
# the factorization algorithm builds L incrementally by repeatedly solving
# the following equation:
#
#   [ L  0 ] [ D 0 ] [ L' l  ] = [ A  a ]    <-- n rows
#   [ l' e ] [ 0 d ] [ 0  e' ]   [ a' b ]    <-- m rows
#
#   [ L*D*L'   L*D*l         ] = [ A  a ]
#   [ l'*D*L'  l'*D*l+e*d*e' ]   [ a' b ]
#
# L*D*L'=A is an existing solution, and a,b are new rows/columns to add to A.
# we compute:
#
#   L * (Dl) = a
#   l = inv(D) * Dl
#   e*d*e' = b - l'*Dl       (m*m LDLT factorization)
#
#
# L-transpose solver
# ------------------
#
# the LT (L-transpose) solver uses the same logic as the standard L-solver,
# with a few tricks to make it work. to solve L^T*X=B we first remap:
#    L to Lhat : Lhat(i,j) = L(n-j,n-i)
#    X to Xhat : Xhat(i) = X(n-i)
#    B to Bhat : Bhat(i) = B(n-i)
# and then solve Lhat*Xhat = Bhat. the current LT solver only supports one
# right hand side, but that's okay as it is not used in the factorizer.
#
#############################################################################
#
# code generation parameters, set in a parameters file:
#    FNAME   : name of source file to generate - a .c file will be made
#    TYPE    : 'f' to build factorizer, 's' to build solver, 't' to build the
#              transpose solver.
#    N1      : block size (size of outer product matrix) (1..9)
#    UNROLL1 : solver inner loop unrolling factor (1..)
#    UNROLL2 : factorizer inner loop unrolling factor (1..)
#    MADD    : if nonzero, generate code for fused multiply-add (0,1)
#    FETCH   : how to fetch data in the inner loop:
#                0 - load in a batch (the `normal way')
#                1 - delay inner loop loads until just before they're needed
#
#############################################################################
#
# TODO
# ----
#
# * dFactorLDLT() is not so efficient for matrix sizes < block size, e.g.
#   redundant calls, zero loads, adds etc
#
#############################################################################
#
# NOTES:
#
# * on the pentium we can prefetch like this:
#     asm ("prefetcht0 %0" : : "m" (*Ai) );
#   but it doesn't seem to help much

require ("BuildUtil");

# get and check code generation parameters.
# the parameters file is perl code that is run with `do'; it is expected to
# set the globals $FNAME, $TYPE, $N1, $UNROLL1, $UNROLL2, $MADD and $FETCH.
error ("Usage: BuildLDLT <parameters-file>") if $#ARGV != 0;
do $ARGV[0];
# `do FILE' traps exceptions like eval does, so a parameters file that fails
# to compile (or dies) must be reported explicitly.
error ("can't process parameters file $ARGV[0]: $@") if $@;

if (!defined($FNAME) || !defined($TYPE) || !defined($N1) ||
    !defined($UNROLL1) || !defined($UNROLL2) || !defined($MADD) ||
    !defined($FETCH)) {
  error ("code generation parameters not defined");
}

# check parameters
error ("bad TYPE") if $TYPE ne 'f' && $TYPE ne 's' && $TYPE ne 't';
error ("bad N1") if $N1 < 1 || $N1 > 9;
error ("bad UNROLL1") if $UNROLL1 < 1;
error ("bad UNROLL2") if $UNROLL2 < 1;
error ("bad MADD") if $MADD != 0 && $MADD != 1;
# FETCH must be 0 or 1. (the test used to read `< 0 && > 1', which can never
# be true, so bad values were silently accepted.)
error ("bad FETCH") if $FETCH != 0 && $FETCH != 1;

#############################################################################
# utility

# functions to handle delayed loading of p and q values.
# bits in the `ploaded' and `qloaded' numbers record what has been loaded,
# so we don't load it again.

sub newLoads
{
  # forget every recorded load: clear the bitmasks that say which p and q
  # registers have been loaded so far.
  $ploaded = $qloaded = 0;
}

sub loadedEverything
{
  # set all 32 low bits in both bitmasks, so that loadP() and loadQ() treat
  # every register index as already loaded and emit nothing.
  $ploaded = $qloaded = 0xffffffff;
}

sub loadP # (i,loadcmd)
{
  # emit `loadcmd' (the code that loads register p<i>) unless bit i of
  # $ploaded says it was already emitted, then record the load.
  my ($index, $loadcmd) = @_;
  my $mask = 1 << $index;
  return if ($ploaded & $mask);
  output ($loadcmd);
  $ploaded |= $mask;
}

sub loadQ # (i,loadcmd)
{
  # emit `loadcmd' (the code that loads register q<i>) unless bit i of
  # $qloaded says it was already emitted, then record the load.
  my ($index, $loadcmd) = @_;
  my $mask = 1 << $index;
  return if ($qloaded & $mask);
  output ($loadcmd);
  $qloaded |= $mask;
}

#############################################################################
# make a fast L solve function.
# this function has a restriction that the leading dimension of X and B must
# be a multiple of the block size.

sub innerOuterProductLoop # (M,k,nrhs,increment)
{
  # emit one step of the outer product accumulation:
  #   M         : number of p values (rows of the Z block)
  #   k         : offset of this step within the unrolled loop body
  #   nrhs      : number of q values (columns of the Z block)
  #   increment : if > 0, how far to advance the data pointers afterwards
  my ($M, $k, $nrhs, $increment) = @_;
  my $transposed = ($TYPE eq 't');

  # builders for the statements that load p<n> and q<n>. the transpose
  # solver uses negated offsets.
  my $pLoadCmd = sub {
    my $n = $_[0];
    return "p$n=ell[".($transposed ? ofs2(-($n-1),0,'lskip')
                                   : ofs2($k,$n-1,'lskip'))."];\n";
  };
  my $qLoadCmd = sub {
    my $n = $_[0];
    return "q$n=ex[".($transposed ? ofs2(-($k),$n-1,'lskip')
                                  : ofs2($k,$n-1,'lskip'))."];\n";
  };

  newLoads();
  if ($FETCH==0) {
    # batch mode: load everything up front, then mark it all as loaded so
    # the delayed loads below emit nothing.
    comment ("load p and q values");
    foreach my $n (1 .. $M) {
      output ($pLoadCmd->($n));
      output ($qLoadCmd->($n)) if $n <= $nrhs;
    }
    loadedEverything();
  }

  comment ("compute outer product and add it to the Z matrix");
  foreach my $r (1 .. $M) {
    foreach my $c (1 .. $nrhs) {
      # delayed mode: each value is loaded just before its first use
      loadP ($r, $pLoadCmd->($r));
      loadQ ($c, $qLoadCmd->($c));
      # with MADD accumulate directly, otherwise go through a temporary
      my $dest = $MADD ? "Z$r$c +=" : "m$r$c =";
      output ("$dest p$r * q$c;\n");
    }
  }

  if ($transposed) {
    # ell always advances by lskip1; ex only moves when asked to
    output ("ell += lskip1;\n");
    output ("ex -= $increment;\n") if $increment > 0;
  }
  elsif ($increment > 0) {
    comment ("advance pointers");
    output ("ell += $increment;\n");
    output ("ex += $increment;\n");
  }

  if ($MADD==0) {
    # fold the temporaries into Z
    foreach my $r (1 .. $M) {
      foreach my $c (1 .. $nrhs) {
        output ("Z$r$c += m$r$c;\n");
      }
    }
  }
}


sub computeRows # (nrhs,rows)
{
  # emit the code that computes one `rows' x `nrhs' block of X, starting at
  # row i of the generated function.
  my ($nrhs, $rows) = @_;
  my $transposed = ($TYPE eq 't');

  comment ("compute all $rows x $nrhs block of X, from rows i..i+$rows-1");

  comment ("set the Z matrix to 0");
  foreach my $r (1 .. $rows) {
    foreach my $c (1 .. $nrhs) {
      output ("Z$r$c=0;\n");
    }
  }
  output ($transposed ? "ell = L - i;\n" : "ell = L + i*lskip1;\n");
  output ("ex = B;\n");

  comment ("the inner loop that computes outer products and adds them to Z");
  output ("for (j=i-$UNROLL1; j >= 0; j -= $UNROLL1) {\n");
  for (my $step=0; $step < $UNROLL1; $step++) {
    # only the last unrolled step advances the pointers
    innerOuterProductLoop ($rows,$step,$nrhs,
                           ($step==$UNROLL1-1) ? $UNROLL1 : 0);
  }

  comment ("end of inner loop");
  output ("}\n");

  if ($UNROLL1 > 1) {
    comment ("compute left-over iterations");
    output ("j += $UNROLL1;\n");
    output ("for (; j > 0; j--) {\n");
    innerOuterProductLoop ($rows,'0',$nrhs,1);
    output ("}\n");
  }

  comment ("finish computing the X(i) block");

  # first row of the block: X = B - Z, nothing else to subtract
  foreach my $c (1 .. $nrhs) {
    my $first = $transposed ? -($c-1) : $c-1;
    output ("Z1$c = ex[".ofs1($first,'lskip')."] - Z1$c;\n");
    output ("ex[".ofs1($first,'lskip')."] = Z1$c;\n");
  }

  # remaining rows also subtract the contribution of the rows of this block
  # that were just finished
  foreach my $r (2 .. $rows) {
    foreach my $c (1 .. $r-1) {
      my $lidx = $transposed ? ofs2(-($r-1),$c-1,'lskip')
                             : ofs2($c-1,$r-1,'lskip');
      output ("p$c = ell[$lidx];\n");
    }
    my $xrow = $transposed ? -($r-1) : $r-1;
    foreach my $c (1 .. $nrhs) {
      output ("Z$r$c = ex[".ofs2($xrow,$c-1,'lskip')."] - Z$r$c");
      foreach my $t (1 .. $r-1) {
        output (" - p$t*Z$t$c");
      }
      output (";\n");
      output ("ex[".ofs2($xrow,$c-1,'lskip')."] = Z$r$c;\n");
    }
  }
}


sub makeFastL1Solve # ( number-of-right-hand-sides )
{
  # emit one complete C solver function.
  #   nrhs : number of right hand sides the generated function handles.
  # when building the factorizer (TYPE 'f') the function is static and its
  # name gets a `_<nrhs>' suffix; otherwise it is a plain `void' function.
  my $nrhs = $_[0];
  my ($i,$j,$k);
  my $funcsuffix = ($TYPE eq 'f') ? "_$nrhs" : '';
  my $staticvoid = ($TYPE eq 'f') ? 'static void' : 'void';

  # function header: the transpose solver has its own fixed name and comment
  if ($TYPE eq 't') {
    output (<<END);

/* solve L^T * x=b, with b containing 1 right hand side.
 * L is an n*n lower triangular matrix with ones on the diagonal.
 * L is stored by rows and its leading dimension is lskip.
 * b is an n*1 matrix that contains the right hand side.
 * b is overwritten with x.
 * this processes blocks of $N1.
 */

void dSolveL1T (const dReal *L, dReal *B, int n, int lskip1)
{
END
  }
  else {
    output (<<END);

/* solve L*X=B, with B containing $nrhs right hand sides.
 * L is an n*n lower triangular matrix with ones on the diagonal.
 * L is stored by rows and its leading dimension is lskip.
 * B is an n*$nrhs matrix that contains the right hand sides.
 * B is stored by columns and its leading dimension is also lskip.
 * B is overwritten with X.
 * this processes blocks of $N1*$N1.
 * if this is in the factorizer source file, n must be a multiple of $N1.
 */

$staticvoid dSolveL1$funcsuffix (const dReal *L, dReal *B, int n, int lskip1)
{
END
  }

  # the dReal declaration list is built piece by piece: every name is
  # followed by a comma, and the final `*ex;' terminates the statement.
  comment ("declare variables - Z matrix, p and q vectors, etc");
  output ("dReal ");
  for ($i=1; $i<=$N1; $i++) {
    for ($j=1; $j<=$nrhs; $j++) {
      output ("Z$i$j,");		# Z matrix
      output ("m$i$j,") if ! $MADD;	# temporary vars where multiplies put
    }
  }
  for ($i=1; $i<=$N1; $i++) {
    output ("p$i,");
    output ("q$i,") if $i <= $nrhs;
  }
  output ("*ex;\nconst dReal *ell;\n");
  # int declarations: lskip2 .. lskip<N1-1> (lskip1 is a function parameter)
  output ("int ");
  for ($i=2; $i<$N1; $i++) {
    output ("lskip$i,");
  }
  output ("i,j;\n");

  if ($TYPE eq 't') {
    # repoint L and B at their last elements and negate the row skip
    comment ("special handling for L and B because we're solving L1 *transpose*");
    output ("L = L + (n-1)*(lskip1+1);\n");
    output ("B = B + n-1;\n");
    output ("lskip1 = -lskip1;\n");
  }

  # the guard only suppresses the comment: the loop below emits nothing
  # unless N1 > 2 anyway
  if ($N1 > 2) {
    comment ("compute lskip values");
    for ($i=2; $i<$N1; $i++) {
      output ("lskip$i = $i*lskip1;\n");
    }
  }

  # main loop over blocks of N1 rows. the stand-alone solvers ('s','t') stop
  # while a full block still fits; the factorizer variant runs up to n.
  comment ("compute all $N1 x $nrhs blocks of X");
  if ($TYPE eq 's' or $TYPE eq 't') {
    output ("for (i=0; i <= n-$N1; i+=$N1) {\n");
  }
  else {
    output ("for (i=0; i < n; i+=$N1) {\n");
  }
  computeRows ($nrhs,$N1);
  comment ("end of outer loop");
  output ("}\n");

  # the stand-alone solvers finish the remaining rows one at a time
  if ($TYPE eq 's' or $TYPE eq 't') {
    comment ("compute rows at end that are not a multiple of block size");
    output ("for (; i < n; i++) {\n");
    computeRows ($nrhs,1);
    output ("}\n");
  }

  # close the generated function body
  output ("}\n");
}

#############################################################################
# make a fast L*D*L' factorizer

# code fragment: this factors an M x M block. if A_or_Z is 0 then it works
# on the $A matrix otherwise it works on the Z matrix. in either case it
# writes the diagonal entries into the `dee' vector.
# it is a simple implementation of the LDLT algorithm, with no tricks.

sub getA # (i,j,A,A_or_Z)
{
  # return the C expression for element (i,j), with i and j zero based.
  my ($row, $col, $A, $useZ) = @_;

  # the Z temporaries are named with one-based indexes: Z11, Z21, ...
  if ($useZ) {
    return 'Z' . ($row+1) . ($col+1);
  }

  # otherwise index the array named by A; note the (j,i) argument order
  return $A . '[' . ofs2($col,$row,'nskip') . ']';
}

sub miniLDLT # (A,A_or_Z,M)
{
  # emit straight-line code that factors an M x M block, taking the elements
  # from getA() and writing the reciprocal diagonal terms into `dee'.
  my ($A, $AZ, $M) = @_;

  # shorthand: C expression for element (row,col) of the block
  my $elem = sub { return getA($_[0],$_[1],$A,$AZ); };

  comment ("factorize $M x $M block " . ($AZ ? "Z,dee" : "$A,dee"));
  comment ("factorize row 1");
  output ("dee[0] = dRecip(".$elem->(0,0).");\n");

  for (my $row=1; $row<$M; $row++) {
    comment ("factorize row ".($row+1));

    # subtract the products with the earlier rows from the elements left of
    # the diagonal
    foreach my $col (1 .. $row-1) {
      output ($elem->($row,$col)." -= ");
      foreach my $t (0 .. $col-1) {
        output (" + ") if $t > 0;
        output ($elem->($row,$t)."*".$elem->($col,$t));
      }
      output (";\n");
    }

    # scale the row by dee and accumulate the sum needed for the diagonal
    output ("sum = 0;\n");
    foreach my $col (0 .. $row-1) {
      output ("q1 = ".$elem->($row,$col).";\n");
      output ("q2 = q1 * dee[$col];\n");
      output ($elem->($row,$col)." = q2;\n");
      output ("sum += q1*q2;\n");
    }
    output ("dee[$row] = dRecip(".$elem->($row,$row)." - sum);\n");
  }
  comment ("done factorizing $M x $M block");
}


sub innerScaleAndOuterProductLoop # (M,k)
{
  # emit one step of the factorizer inner loop for offset k: load M values
  # of ell, scale them by dee[k], store them back, and add the lower
  # triangle of their outer product to Z.
  my ($M, $k) = @_;

  foreach my $r (1 .. $M) {
    output ("p$r = ell[".ofs2($k,$r-1,'nskip')."];\n");
  }
  output ("dd = dee[$k];\n");
  foreach my $r (1 .. $M) {
    output ("q$r = p$r*dd;\n");
  }
  foreach my $r (1 .. $M) {
    output ("ell[".ofs2($k,$r-1,'nskip')."] = q$r;\n");
  }

  # only the lower triangle (col <= row) of Z is maintained
  foreach my $r (1 .. $M) {
    foreach my $c (1 .. $r) {
      my $dest = $MADD ? "Z$r$c +=" : "m$r$c =";
      output ("$dest p$r*q$c;\n");
    }
  }

  if ($MADD==0) {
    # without fused multiply-add the products went to temporaries
    foreach my $r (1 .. $M) {
      foreach my $c (1 .. $r) {
        output ("Z$r$c += m$r$c;\n");
      }
    }
  }
}


sub diagRows # (M)
{
  # emit the code that scales the M x i block at A(i,0) and accumulates the
  # lower triangle of the outer product matrix Z.
  #   M : number of rows in the block
  my $M=$_[0];
  # loop counters must be lexical: they used to be undeclared package
  # globals, so calling this sub clobbered the caller's $i and $j.
  my ($i,$j);
  comment ("scale the elements in a $M x i block at A(i,0), and also");
  comment ("compute Z = the outer product matrix that we'll need.");
  for ($i=1; $i<=$M; $i++) {
    for ($j=1; $j<=$i; $j++) {
      output ("Z$i$j = 0;\n");
    }
  }
  output ("ell = A+i*nskip1;\n");
  output ("dee = d;\n");
  # main loop, unrolled UNROLL2 times
  output ("for (j=i-$UNROLL2; j >= 0; j -= $UNROLL2) {\n");
  for ($i=0; $i < $UNROLL2; $i++) {
    innerScaleAndOuterProductLoop ($M,$i);
  }
  output ("ell += $UNROLL2;\n");
  output ("dee += $UNROLL2;\n");
  output ("}\n");

  if ($UNROLL2 > 1) {
    # the unrolled loop leaves up to UNROLL2-1 iterations undone
    comment ("compute left-over iterations");
    output ("j += $UNROLL2;\n");
    output ("for (; j > 0; j--) {\n");
    innerScaleAndOuterProductLoop ($M,0);
    output ("ell++;\n");
    output ("dee++;\n");
    output ("}\n");
  }
}


sub diagBlock # (M)
{
  # emit the code that finishes the M x M diagonal block at A(i,i): form the
  # block in Z, factor it with miniLDLT, and store the part below the
  # diagonal back through ell.
  #   M : size of the diagonal block
  my $M = $_[0];
  # loop counters must be lexical: they used to be undeclared package
  # globals, so calling this sub clobbered the caller's $i and $j.
  my ($i,$j);
  comment ("solve for diagonal $M x $M block at A(i,i)");
  for ($i=1; $i<=$M; $i++) {
    for ($j=1; $j<=$i; $j++) {
      output ("Z$i$j = ell[".ofs2($j-1,$i-1,'nskip')."] - Z$i$j;\n");
    }
  }
  output ("dee = d + i;\n");
  # factor in place on the Z temporaries (A_or_Z = 1, so no array name)
  miniLDLT ('',1,$M);
  for ($i=2; $i<=$M; $i++) {
    for ($j=1; $j<$i; $j++) {
      output ("ell[".ofs2($j-1,$i-1,'nskip')."] = Z$i$j;\n");
    }
  }
}


sub makeFastLDLT
{
  # emit the complete dFactorLDLT() C function. it calls the static
  # dSolveL1_<n> helpers by name, one per block height.

  # function header
  output (<<END);


void dFactorLDLT (dReal *A, dReal *d, int n, int nskip1)
{
END

  # integer locals: loop counters plus nskip2 .. nskip<N1-1>
  output ("int i,j");
  for (my $m=2; $m<$N1; $m++) {
    output (",nskip$m");
  }
  output (";\n");

  # dReal locals: p and q vectors, then the lower triangle of Z (with the
  # m temporaries when fused multiply-add is not used)
  output ("dReal sum,*ell,*dee,dd,p1,p2");
  foreach my $n (3 .. $N1) {
    output (",p$n");
  }
  foreach my $n (1 .. $N1) {
    output (",q$n");
  }
  foreach my $r (1 .. $N1) {
    foreach my $c (1 .. $r) {
      output (",Z$r$c");
      output (",m$r$c") if ! $MADD;
    }
  }
  output (";\n");
  output ("if (n < 1) return;\n");
  # output ("nskip1 = dPAD(n);\n");   ... not any more
  for (my $m=2; $m<$N1; $m++) {
    output ("nskip$m = $m*nskip1;\n");
  }

  # main loop over full blocks of N1 rows
  output ("\nfor (i=0; i<=n-$N1; i += $N1) {\n");
  comment ("solve L*(D*l)=a, l is scaled elements in $N1 x i block at A(i,0)");
  output ("dSolveL1_$N1 (A,A+i*nskip1,i,nskip1);\n");

  diagRows ($N1);
  diagBlock ($N1);
  output ("}\n");

  # the remaining 0 .. N1-1 rows are handled by a switch with one case each
  comment ("compute the (less than $N1) rows at the bottom");
  output ("switch (n-i) {\n");
  output ("case 0:\n");
  output ("break;\n\n");

  for (my $left=1; $left<$N1; $left++) {
    output ("case $left:\n");
    output ("dSolveL1_$left (A,A+i*nskip1,i,nskip1);\n");
    diagRows ($left);
    diagBlock ($left);
    output ("break;\n\n");
  }

  output ("default: *((char*)0)=0;  /* this should never happen! */\n");
  output ("}\n");

  output ("}\n");
}

#############################################################################
# write source code

# the bareword handle FOUT is kept on purpose: presumably output() in
# BuildUtil prints to it by name - TODO confirm.
# use the three-argument open so characters in $FNAME cannot alter the open
# mode, and report the system error on failure.
open (FOUT, '>', "$FNAME.c") or die "can't open $FNAME.c for writing: $!";

# file and function header
output (<<END);
/* generated code, do not edit. */

#include "ode/matrix.h"
END

if ($TYPE eq 'f') {
  # the factorizer needs one static solver for each block height 1..N1
  for (my $i=1; $i <= $N1; $i++) {
    makeFastL1Solve ($i);
  }
  makeFastLDLT;
}
else {
  makeFastL1Solve (1);
  # NOTE(review): no sub named makeRealFastL1Solve is defined in this file,
  # so without a predeclaration perl parses the next line as a bareword
  # string and it does nothing - confirm whether a call was intended.
  makeRealFastL1Solve;
}
# buffered write errors only show up at close, so check it
close (FOUT) or die "can't close $FNAME.c: $!";
