#!/usr/bin/perl

# pcFactorStan is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# pcFactorStan is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with pcFactorStan.  If not, see <http://www.gnu.org/licenses/>.

use Modern::Perl '2018';
use Fatal qw(open rename unlink);
use File::Compare;

# Refuse to run unless ./tools/genStan is reachable from the current
# directory; the inst/stan/ output paths used by this script are relative.
die "Must be run from top level" if !-e './tools/genStan';

# Stan `data` declarations placed at the top of the data block of every
# generated model (interpolated by mkUnidim and mvtDataCommon).
my $commonData =
'  real alphaScalePrior;
  // dimensions
  int<lower=1> NPA;             // worths or players or objects or things
  int<lower=1> NCMP;            // unique comparisons
  int<lower=1> N;               // observations
  int<lower=1> numRefresh;      // when change in item/pa1/pa2';

# Stan `data` declarations for the response arrays; interpolated into the
# data block after the model-specific declarations.
my $responseData =
'  // response data
  array[numRefresh] int<lower=1, upper=NPA> pa1;
  array[numRefresh] int<lower=1, upper=NPA> pa2;
  array[NCMP] int weight;
  array[NCMP] int pick;
  array[numRefresh] int refresh;
  array[numRefresh] int numOutcome;';

# Declaration of log_lik.  It is interpolated immediately after the opening
# brace of `generated quantities`, hence the leading and trailing newline.
my $looDecl =
  '
  array[N] real log_lik;
';

# Stan block for the unidimensional models that fills log_lik by calling
# pairwise_loo once per refresh group, advancing cmpStart by refresh[rx]
# and the log_lik cursor by numOutcome[rx].
my $unidim_ll =
'{
    int cmpStart = 1;
    int cur = 1;

    for (rx in 1:numRefresh) {
      int nout = numOutcome[rx];
      int last = cur + nout - 1;
      log_lik[cur:last] =
        pairwise_loo(rcat, weight, nout, cmpStart, refresh[rx],
                     scale, alpha, theta[pa1[rx]], theta[pa2[rx]], rawCumTh);
      cmpStart += refresh[rx];
      cur += numOutcome[rx];
    }
  }
';

# Complete Stan sampling statement for rawThreshold; interpolated into the
# model block of every generated model.
my $threshold_prior =
  'rawThreshold ~ beta(1.1, 2);';

# Right-hand side of the prior on alpha.  There is no trailing semicolon:
# each user writes `alpha ~ ` before it and `;` after it.
my $alpha_prior =
  'normal(1.749, alphaScalePrior) T[0,]';

# mkUnidim($adapt, $ll)
#
# Build the Stan source (everything after the preamble) for a
# unidimensional model and return it as one string.
#
#   $adapt - true:  `varCorrection` is read as data, `sigma` is the free
#                   parameter, alpha is fixed at 1.749 in transformed data
#                   and scale is derived from sd(theta).
#            false: `scale` is read as data and `alpha` is the free
#                   parameter.
#   $ll    - true:  also emit the log_lik declaration and the pairwise_loo
#                   loop in generated quantities.
sub mkUnidim {
    my ($adapt, $ll) = @_;

    # Template fragments that depend on the adaptive / fixed-scale choice.
    my ($data, $xdata, $par, $scaleDef, $theta, $prior, $thetaPP);
    if ($adapt) {
        $data     = "varCorrection";
        $xdata    = "\n  real alpha = 1.749;";
        $par      = "sigma";
        $scaleDef = "
  vector[NPA] theta = sqrt(sigma) * rawTheta;
  real scale = sd(theta) ^ varCorrection;";
        $theta    = "theta";
        $prior    = 'sigma ~ inv_gamma(1, 1);';
        $thetaPP  = "real thetaVar = variance(theta);";
    }
    else {
        $data     = "scale";
        $xdata    = '';
        $par      = "alpha";
        $scaleDef = "";
        $theta    = "rawTheta";
        $prior    = "alpha ~ $alpha_prior;";
        $thetaPP  = "vector[NPA] theta = rawTheta;
  theta -= mean(theta);
  theta /= sd(theta);";
    }

    # Optional pointwise log-likelihood pieces.
    my ($llDecl, $llBody) = $ll ? ($looDecl, $unidim_ll) : ('', '');

    return qq[data {
$commonData
  int<lower=1> NTHRESH;         // number of thresholds
  real $data;
$responseData
}
transformed data {
  array[NCMP] int rcat;$xdata

  for (cmp in 1:NCMP) {
    rcat[cmp] = pick[cmp] + NTHRESH + 1;
  }
}
parameters {
  vector[NPA] rawTheta;
  vector<lower=0,upper=1>[NTHRESH] rawThreshold;
  real<lower=0> $par;
}
transformed parameters {$scaleDef
  vector[NTHRESH] threshold;
  vector[NTHRESH] rawCumTh;
  real maxSpan = max($theta) - min($theta);
  threshold = maxSpan * rawThreshold;
  rawCumTh = cumulative_sum(threshold);
}
model {

  $prior
  rawTheta ~ std_normal();
  $threshold_prior

  {
    int cmpStart = 1;
    for (rx in 1:numRefresh) {
      target += pairwise_logprob(rcat, weight, cmpStart, refresh[rx],
                                 scale, alpha, ${theta}[pa1[rx]], ${theta}[pa2[rx]], rawCumTh);
      cmpStart += refresh[rx];
    }
  }
}
generated quantities {$llDecl
  $thetaPP
$llBody}
];
}

# mvtDataCommon($data, $tdataDecl, $tdata)
#
# Return the `data` and `transformed data` blocks shared by the
# multivariate models.  All three arguments are optional Stan fragments;
# an undefined argument is treated as the empty string.
#
#   $data      - extra declarations appended inside the data block, right
#                after `vector[NITEMS] scale;`.
#   $tdataDecl - extra declarations appended after the rcat declaration in
#                transformed data.
#   $tdata     - extra statements appended after the block that fills rcat.
sub mvtDataCommon {
    my ($extraData, $extraTDataDecl, $extraTDataBody) = @_;
    $extraData      //= '';
    $extraTDataDecl //= '';
    $extraTDataBody //= '';

    return qq[data {
$commonData
  int<lower=1> NITEMS;
  array[NITEMS] int<lower=1> NTHRESH;         // number of thresholds
  array[NITEMS] int<lower=1> TOFFSET;
  vector[NITEMS] scale;$extraData
$responseData
  array[numRefresh] int item;
}
transformed data {
  int totalThresholds = sum(NTHRESH);
  array[NCMP] int rcat;$extraTDataDecl
  {
    int cmpStart = 0;
    for (rx in 1:numRefresh) {
      int ix = item[rx];
      for (cmp in 1:refresh[rx]) {
        rcat[cmpStart + cmp] = pick[cmpStart + cmp] + NTHRESH[ix] + 1;
      }
      cmpStart += refresh[rx];
    }
  }$extraTDataBody
}];
}

# Parameter declaration for the thresholds of all items, stored in one
# vector of length totalThresholds (declared in mvtDataCommon).
my $multivariateThresholdParam =
  'vector<lower=0,upper=1>[totalThresholds] rawThreshold;';

# Transformed-parameter declarations that go with $multivariateThresholdTParam.
my $multivariateThresholdTParamDecl =
  'vector[totalThresholds] threshold;
  vector[totalThresholds] rawCumTh;';

# Transformed-parameter statements: for each item, scale its slice of
# rawThreshold (TOFFSET[ix] .. TOFFSET[ix]+NTHRESH[ix]-1) by the span of
# that item's theta column and store the cumulative sum in rawCumTh.
my $multivariateThresholdTParam =
'for (ix in 1:NITEMS) {
    real maxSpan = max(theta[,ix]) - min(theta[,ix]);
    int from = TOFFSET[ix];
    int to = TOFFSET[ix] + NTHRESH[ix] - 1;
    threshold[from:to] = maxSpan * rawThreshold[from:to];
    rawCumTh[from:to] = cumulative_sum(threshold[from:to]);
  }';

# Model-block statements: add pairwise_logprob to target once per refresh
# group, using the item-specific scale, alpha, theta column and threshold
# slice.
my $multivariateQuickLikelihood =
'{
    int cmpStart = 1;
    for (rx in 1:numRefresh) {
      int ix = item[rx];
      int from = TOFFSET[ix];
      int to = TOFFSET[ix] + NTHRESH[ix] - 1;
      target += pairwise_logprob(rcat, weight, cmpStart, refresh[rx],
                                 scale[ix], alpha[ix], theta[pa1[rx], ix],
                                 theta[pa2[rx], ix], rawCumTh[from:to]);
      cmpStart += refresh[rx];
    }
  }';

# Generated-quantities statements for the multivariate models that fill
# log_lik via pairwise_loo.  The fragment starts with a blank line because
# it is interpolated directly after the preceding statement.
my $multivariateLoo = '

  {
    int cmpStart = 1;
    int cur = 1;
    for (rx in 1:numRefresh) {
      int ix = item[rx];
      int from = TOFFSET[ix];
      int to = TOFFSET[ix] + NTHRESH[ix] - 1;
      int nout = numOutcome[rx];
      int last = cur + nout - 1;
      log_lik[cur:last] =
        pairwise_loo(rcat, weight, nout, cmpStart, refresh[rx],
                     scale[ix], alpha[ix],
                     theta[pa1[rx],ix], theta[pa2[rx],ix], rawCumTh[from:to]);
      cmpStart += refresh[rx];
      cur += numOutcome[rx];
    }
  }';

# mkCorModel($ll)
#
# Return the Stan source (everything after the preamble) for the
# correlation model: the shared multivariate data blocks with an extra
# `corLKJPrior` datum, followed by the parameters, transformed parameters,
# model and generated quantities blocks.
#
#   $ll - true: also emit the log_lik declaration and the
#         $multivariateLoo body in generated quantities.
sub mkCorModel {
    my ($ll) = @_;

    my ($llDecl, $llBody) = ('', '');
    if ($ll) {
        $llDecl = $looDecl;
        $llBody = $multivariateLoo;
    }

    my $header = mvtDataCommon('
  real corLKJPrior;');

    my $body = qq[
parameters {
  $multivariateThresholdParam
  vector<lower=0>[NITEMS] alpha;
  matrix[NPA,NITEMS]      rawTheta;
  cholesky_factor_corr[NITEMS] rawThetaCorChol;
}
transformed parameters {
  $multivariateThresholdTParamDecl
  matrix[NPA,NITEMS]      theta;

  // non-centered parameterization due to thin data
  for (pa in 1:NPA) {
    theta[pa,] = (rawThetaCorChol * rawTheta[pa,]')';
  }
  $multivariateThresholdTParam
}
model {
  rawThetaCorChol ~ lkj_corr_cholesky(corLKJPrior);
  for (pa in 1:NPA) {
    rawTheta[pa,] ~ std_normal();
  }
  $threshold_prior
  for (ix in 1:NITEMS) alpha[ix] ~ $alpha_prior;
  $multivariateQuickLikelihood
}
generated quantities {$llDecl
  corr_matrix[NITEMS] thetaCor;
  thetaCor = multiply_lower_tri_self_transpose(rawThetaCorChol);$llBody
}
];

    return $header . $body;
}

# mkFactorModel($ll, $psi)
#
# Return the Stan source (everything after the preamble) for the factor
# model: the shared multivariate data blocks extended with the factor
# structure (NFACTORS, NPATHS, factorItemPath, ...), followed by the
# parameters, transformed parameters, model and generated quantities
# blocks.
#
#   $ll  - true: also emit the log_lik declaration and the
#          $multivariateLoo body in generated quantities.
#   $psi - true: add the Cholesky factor CholPsi, use
#          rawFactorPsi = rawFactor * CholPsi' as the factor scores and
#          report Psi = CholPsi * CholPsi' in generated quantities.
#          false: use rawFactor directly and put a std_normal prior on its
#          first column only.
sub mkFactorModel {
    my ($ll, $psi) = @_;
    my $llDecl = $ll ? $looDecl : '';
    my $llBody = $ll ? $multivariateLoo : '';

    # Fragments that are only non-trivial when $psi is requested.
    my $psiParam = '';
    my $psiTparamDecl = '';
    my $psiTparam = '';
    my $rawFactor = 'rawFactor';
    my $latentPrior = '
  rawFactor[,1] ~ std_normal();';
    my $psiGenQuantDecl = '';
    my $psiFlipSign = '';
    if ($psi) {
        $psiParam = '
  cholesky_factor_corr[NFACTORS] CholPsi;';
        $psiTparamDecl = '
  matrix[NPA,NFACTORS] rawFactorPsi;';
        $psiTparam = "
  rawFactorPsi = rawFactor * CholPsi';";
        $rawFactor = 'rawFactorPsi';
        $latentPrior = '
  CholPsi ~ lkj_corr_cholesky(2.5);
  for (xx in 1:NPA) {
    rawFactor[xx,] ~ std_normal();
  }';
        $psiGenQuantDecl = "
  matrix[NFACTORS,NFACTORS] Psi = CholPsi * CholPsi';";
        # Flipping the sign of factor fx maps Psi to D * Psi * D where D is
        # the identity with -1 at position fx.  Negating column fx and then
        # row fx negates Psi[fx,fx] twice, which already restores it, so the
        # diagonal must not be negated a third time.
        $psiFlipSign = '
      Psi[,fx] = -Psi[,fx];
      Psi[fx,] = -Psi[fx,];';
    }

    # The reject() messages pass the loop variable px (not the string
    # literal "px") so that the offending path index is reported.
    my $header = mvtDataCommon(
        "
  real propShape;
  int<lower=1> NFACTORS;
  array[NFACTORS] real factorScalePrior;
  int<lower=1> NPATHS;
  array[2,NPATHS] int factorItemPath;  // 1 is factor index, 2 is item index",
        '
  vector[NPATHS] pathScalePrior;',
        '
  for (px in 1:NPATHS) {
    int fx = factorItemPath[1,px];
    int ix = factorItemPath[2,px];
    if (fx < 1 || fx > NFACTORS) {
      reject("factorItemPath[1,",px,"] names factor ", fx, " (NFACTORS=",NFACTORS,")");
    }
    if (ix < 1 || ix > NITEMS) {
      reject("factorItemPath[2,",px,"] names item ", ix, " (NITEMS=",NITEMS,")");
    }
    pathScalePrior[px] = factorScalePrior[fx];
  }');

    return $header . qq[
parameters {
  array[NITEMS] real<lower=0> alpha;
  $multivariateThresholdParam$psiParam
  matrix[NPA,NFACTORS] rawFactor;      // do not interpret, see factor
  vector<lower=0,upper=1>[NPATHS] rawLoadings; // do not interpret, see factorLoadings
  matrix[NPA,NITEMS] rawUniqueTheta; // do not interpret, see uniqueTheta
  vector<lower=0,upper=1>[NITEMS] rawUnique;      // do not interpret, see unique
}
transformed parameters {
  $multivariateThresholdTParamDecl$psiTparamDecl
  matrix[NPA,NITEMS] theta;
  vector[NPATHS] rawPathProp;  // always positive
  array[NITEMS,1+NFACTORS] real rawPerComponentVar;
  for (ix in 1:NITEMS) {
    theta[,ix] = rawUniqueTheta[,ix] * (2*rawUnique[ix]-1);
    rawPerComponentVar[ix, 1] = variance(theta[,ix]);
  }
  for (fx in 1:NFACTORS) {
    for (ix in 1:NITEMS) rawPerComponentVar[ix,1+fx] = 0;
  }$psiTparam
  for (px in 1:NPATHS) {
    int fx = factorItemPath[1,px];
    int ix = factorItemPath[2,px];
    vector[NPA] theta1 = (2*rawLoadings[px]-1) * ${rawFactor}[,fx];
    rawPerComponentVar[ix,1+fx] = variance(theta1);
    theta[,ix] += theta1;
  }
  for (px in 1:NPATHS) {
    int fx = factorItemPath[1,px];
    int ix = factorItemPath[2,px];
    real resid = 0;
    real pred;
    for (cx in 1:(1+NFACTORS)) {
      if (cx == fx+1) {
        pred = rawPerComponentVar[ix,cx];
      } else {
        resid += rawPerComponentVar[ix,cx];
      }
    }
    rawPathProp[px] = pred / (pred + resid);
  }
  $multivariateThresholdTParam
}
model {
  for (ix in 1:NITEMS) alpha[ix] ~ $alpha_prior;
  $threshold_prior$latentPrior
  rawLoadings ~ beta(propShape, propShape);
  rawUnique ~ beta(propShape, propShape);
  for (ix in 1:NITEMS) {
    rawUniqueTheta[,ix] ~ std_normal();
  }
  $multivariateQuickLikelihood
  // 1.0 excessive, 1.5 not enough
  target += normal_lpdf(logit(0.5 + rawPathProp/2.0) | 0, pathScalePrior);
}
generated quantities {$llDecl
  vector[NPATHS] pathProp = rawPathProp;
  matrix[NPA,NFACTORS] factor = ${rawFactor};
  matrix[NITEMS,NITEMS] residualItemCor;$psiGenQuantDecl

  {
    matrix[NPA,NITEMS] residual;
    for (ix in 1:NITEMS) {
      residual[,ix] = rawUniqueTheta[,ix] * (2*rawUnique[ix]-1);
      residual[,ix] -= mean(residual[,ix]);
      residual[,ix] /= sd(residual[,ix]);
    }
    residualItemCor = crossprod(residual);
    residualItemCor = quad_form_diag(residualItemCor, 1.0 ./ sqrt(diagonal(residualItemCor)));
  }

  {
    vector[NPATHS] pathLoadings = (2*rawLoadings-1);
    array[NFACTORS] int rawSeenFactor;
    array[NFACTORS] int rawNegateFactor;
    for (fx in 1:NFACTORS) rawSeenFactor[fx] = 0;
    for (px in 1:NPATHS) {
      int fx = factorItemPath[1,px];
      int ix = factorItemPath[2,px];
      if (rawSeenFactor[fx] == 0) {
        rawSeenFactor[fx] = 1;
        rawNegateFactor[fx] = rawLoadings[px] < 0.5;
      }
      if (rawNegateFactor[fx]) {
        pathLoadings[px] = -pathLoadings[px];
      }
    }
    for (fx in 1:NFACTORS) {
      if (!rawNegateFactor[fx]) continue;
      factor[,fx] = -factor[,fx];$psiFlipSign
    }
    for (fx in 1:NPATHS) {
      if (pathLoadings[fx] < 0) pathProp[fx] = -pathProp[fx];
    }
  }
  for (fx in 1:NFACTORS) {
    factor[,fx] -= mean(factor[,fx]);
    factor[,fx] /= sd(factor[,fx]);
  }$llBody
}
];
}

# openAndPreamble($f)
#
# Open inst/stan/$f.stan.new for writing (truncating any existing file),
# write the license include and the `functions` block that every generated
# model starts with, and return the still-open filehandle so the caller
# can append the model body.
#
# Uses three-argument open so that the mode can never be influenced by the
# contents of $f.  Dies if the file cannot be opened.
sub openAndPreamble {
    my ($f) = @_;
    my $path = "inst/stan/$f.stan.new";
    open my $fh, '>', $path or die "Cannot open $path for writing: $!";
    print {$fh} '#include /pre/license.stan
functions {
#include /functions/pairwise.stan
}
';
    return $fh;
}

# Output file stem, generator and generator arguments for each Stan model,
# in the order in which the files are written and installed.
my @models = (
    [ 'unidim',         \&mkUnidim,      0, 0 ],
    [ 'unidim_adapt',   \&mkUnidim,      1, 0 ],
    [ 'unidim_ll',      \&mkUnidim,      0, 1 ],
    [ 'correlation',    \&mkCorModel,    0 ],
    [ 'correlation_ll', \&mkCorModel,    1 ],
    [ 'factor1',        \&mkFactorModel, 0, 0 ],
    [ 'factor1_ll',     \&mkFactorModel, 1, 0 ],
    [ 'factor',         \&mkFactorModel, 0, 1 ],
    [ 'factor_ll',      \&mkFactorModel, 1, 1 ],
);

# Phase 1: write every model to inst/stan/<stem>.stan.new.  The lexical
# handle goes out of scope at the end of each iteration, which closes the
# file before it is compared below.
for my $model (@models) {
    my ($stem, $generate, @args) = @$model;
    my $fh = openAndPreamble($stem);
    print {$fh} $generate->(@args);
}

# Phase 2: install each new file only if it differs from the existing one,
# so that unchanged models keep their old file.
for my $stem (map { $_->[0] } @models) {
    my $target = "inst/stan/$stem.stan";
    my $fresh  = "$target.new";
    if (-e $target && compare($target, $fresh) == 0) {
        unlink($fresh);
    }
    else {
        rename($fresh, $target);
    }
}
