Raw File
t_invert4_precwilson.cc
// $Id: t_invert4_precwilson.cc,v 1.4 2004-03-25 14:31:20 bjoo Exp $

#include <iostream>
#include <sstream>
#include <iomanip>
#include <string>

#include <cstdio>

#include <stdlib.h>
#include <sys/time.h>
#include <math.h>

#include "chroma.h"
#include "invcg2_timing_hacks_2.h"
#include "invcg2_prec_wilson.h"

using namespace QDP;
using namespace std;


//! Selects how main() initialises the gauge field.
enum GaugeStartType {
  COLD_START = 0,  //!< every link u(j) set to Real(1)
  HOT_START  = 1   //!< random links, then reunitarised
};
//! Run parameters for this test; populated by main().
struct Params_t
{
  multi1d<int>   nrow;              //!< lattice size handed to Layout::setLattSize (read from DATA)
  multi1d<int>   boundary;          //!< boundary values recorded in the XMLDAT output
  GaugeStartType gauge_start_type;  //!< cold or hot gauge start
};

/*! \brief Cross-check the specialised even-odd preconditioned Wilson CG
 *  against the generic CG driven by the fermion action's linear operator.
 *
 * The generic solver InvCG2 uses the linear operator obtained from
 * EvenOddPrecWilsonFermAct. The specialised solver InvCG2EvenOddPrecWilsLinOp
 * works directly on a WilsonDslash built from the connect state's links.
 * Both are run on the same Gaussian source with the same mass, target
 * residual and iteration cap. The function prints both iteration counts and
 * the squared norm (norm2) of the difference of the two solutions on
 * checkerboard rb[1].
 *
 * \param u  gauge field used to build the connect state
 *
 * Throws a std::string if the action's linear operator is not an
 * EvenOddPrecWilsonLinOp.
 */
void checkInverter(multi1d<LatticeColorMatrix>& u)
{
  // Shared solver settings. Previously the specialised solver was capped
  // at 16 iterations (vs 100000 for the generic one), so it could stop
  // well short of the target residual and the comparison was meaningless.
  const Real RsdCG = Real(1.0e-7);
  const int  MaxCG = 100000;
  Real mass = 0.5;

  // Periodic in directions 0..2, antiperiodic in direction 3
  multi1d<int> boundary(4);
  boundary[0] = 1; 
  boundary[1] = 1;
  boundary[2] = 1;
  boundary[3] = -1;

  Handle<FermBC<LatticeFermion> >  
    fbc(new SimpleFermBC<LatticeFermion>(boundary));
  
  EvenOddPrecWilsonFermAct  S_w(fbc, mass);
  
  // Apply boundary to u
  Handle<const ConnectState> connect_state(S_w.createState(u));

  // Downcast to the concrete operator type, checking the result before it
  // is wrapped and dereferenced: a failed cast would otherwise hand a null
  // pointer to InvCG2.
  const EvenOddPrecWilsonLinOp* D_op_ptr =
    dynamic_cast<const EvenOddPrecWilsonLinOp *>(S_w.linOp(connect_state));
  if (D_op_ptr == 0)
  {
    QDPIO::cerr << "checkInverter: linOp is not an EvenOddPrecWilsonLinOp" << endl;
    throw string("checkInverter: linOp is not an EvenOddPrecWilsonLinOp");
  }
  Handle<const EvenOddPrecWilsonLinOp> D_op(D_op_ptr);

  // Same Gaussian source and zero initial guess for both solvers
  LatticeFermion chi;
  gaussian(chi);

  LatticeFermion psi;
  LatticeFermion psi2;
  psi  = zero;
  psi2 = zero;

  int n_count;
  QDPIO::cout << "Running General Solver with LinOp from FermAct" << endl;
  InvCG2(*D_op,
         chi,
         psi,
         RsdCG,
         MaxCG,
         n_count);

  // Get packed gauge field from connect state (bind, don't copy)
  const multi1d<LatticeColorMatrix>& gauge_with_bc = (*connect_state).getLinks();
  WilsonDslash dsl(gauge_with_bc);

  int n_count2;
  QDPIO::cout << "Running SuperSpecialised Solver with Dslash" << endl;
  InvCG2EvenOddPrecWilsLinOp(dsl,
                             chi,
                             psi2,
                             mass,
                             RsdCG,
                             MaxCG,
                             n_count2);

  QDPIO::cout << "General Solver took " << n_count << " iters" << endl;
  QDPIO::cout << "Super   Solver took " << n_count2 << " iters " << endl;

  // Compare the two solutions on checkerboard rb[1] only (presumably the
  // checkerboard the preconditioned solvers solve on). norm2 is a squared
  // norm, and the quantity is the solution difference, not the source.
  LatticeFermion r;
  r[rb[1]] = psi2 - psi;
  Double psi_norm_diff = norm2(r, rb[1]);
  QDPIO::cout << " || psi2 - psi ||^2 = " << psi_norm_diff << endl;
}

int main(int argc, char **argv)
{
  // Put the machine into a known state
  QDP_initialize(&argc, &argv);

  Params_t params;

  // Read params
  XMLReader reader("DATA");
  string stype;
  try { 
    read(reader, "/t_invert/params/nrow", params.nrow);
    read(reader, "/t_invert/params/gauge_start_type", stype);
  }
  catch(const string &error) { 
    QDPIO::cerr << "Error : " << error << endl;
    throw;
  }
  reader.close();

  if( stype == "HOT_START" ) { 
    params.gauge_start_type = HOT_START;
  }
  else if( stype == "COLD_START" ) { 
    params.gauge_start_type = COLD_START;
  }

  QDPIO::cout << "Gauge start type " ;
  switch (params.gauge_start_type) { 
  case HOT_START:
    QDPIO::cout << "hot start" << endl;
    break;
  case COLD_START:
    QDPIO::cout << "cold start" << endl;
    break;
  default:
    QDPIO::cout << endl;
    QDPIO::cerr << "Unknown gauge start type " << endl;
  }

  params.boundary.resize(4);
  params.boundary[0] = 1;
  params.boundary[1] = 1;
  params.boundary[2] = 1;
  params.boundary[3] = 1;


  // Setup the lattice
  Layout::setLattSize(params.nrow);
  Layout::create();

  XMLFileWriter xml("XMLDAT");
  push(xml,"t_invert");
  push(xml,"params");
  write(xml, "nrow", params.nrow);
  write(xml, "boundary", params.boundary);
  write(xml, "gauge_start_type", stype);
  pop(xml); // Params

  // Create a FermBC
  Handle<FermBC<LatticeFermion> >  fbc(new SimpleFermBC<LatticeFermion>(params.boundary));
  
  // The Gauge Field
  multi1d<LatticeColorMatrix> u(Nd);
  
  switch ((GaugeStartType)params.gauge_start_type) { 
  case COLD_START:
    for(int j = 0; j < Nd; j++) { 
      u(j) = Real(1);
    }
    break;
  case HOT_START:
    // Hot (disordered) start
    for(int j=0; j < Nd; j++) { 
      random(u(j));
      reunit(u(j));
    }
    break;
  default:
    ostringstream startup_error;
    startup_error << "Unknown start type " << params.gauge_start_type <<endl;
    throw startup_error.str();
  }


  // Measure the plaquette on the gauge
  Double w_plaq, s_plaq, t_plaq, link;
  MesPlq(u, w_plaq, s_plaq, t_plaq, link);
  push(xml, "plaquette");
  write(xml, "w_plaq", w_plaq);
  write(xml, "s_plaq", s_plaq);
  write(xml, "t_plaq", t_plaq);
  write(xml, "link", link);
  pop(xml);


  checkInverter(u);

  WilsonDslash D(u);

  LatticeFermion chi;
  LatticeFermion psi;
  StopWatch swatch;
  Double mydt;
  int iter;
  Real mass = Real(0);
  for(iter=1; ; iter <<= 1)
  {
    psi = zero;
    QDPIO::cout << "Let 0 action inverter iterate "<< iter << " times" << endl;

    gaussian(chi);
    swatch.reset();
    swatch.start();

    InvCG2EvenOddPrecWilsLinOpTHack(D,
			       chi,
			       psi,
			       mass, 
			       Real(1.0e-6),
			       10000,
			       iter);

    swatch.stop();
                                                                                
    mydt=Double(swatch.getTimeInSeconds());

    Internal::globalSum(mydt);

    mydt /= Double(Layout::numNodes());
                                                                                
    QDPIO::cout << "Time was " << mydt << " seconds" << endl;

    if ( toBool(mydt > Double(1) ) )
      break;
  }


  // Snarfed from SZIN
  int  N_dslash = 1320;
  int  N_mpsi   = 2*12 + 2*24 + 2*N_dslash;
  int  Nflops_c = (24 + 2*N_mpsi) + (48);     
  int Nflops_s = (2*N_mpsi + (2*48+2*24));   
  Double Nflops;

  multi1d<Double> mflops(10);
  multi1d<Double> mydt_a(10);
  
  for (int j=0; j < 10; ++j)
  {
    psi = zero;

    swatch.reset();
    swatch.start();

    InvCG2EvenOddPrecWilsLinOpTHack(D,
                                    chi,
                                    psi,
                                    mass,
                                    Real(1.0e-6),
                                    10000,
                                    iter);

    swatch.stop();
										    mydt=Double(swatch.getTimeInSeconds());

    Internal::globalSum(mydt);

    mydt /= Double(Layout::numNodes());
                                                                                
    mydt_a[j] = Double(1.0e6)*mydt/(Double(Layout::sitesOnNode())/Double(2));
                                                                                
    // Flop count for inverter 
    Nflops   = Double(Nflops_c) + Double(iter*Nflops_s);
    mflops[j] = Nflops / mydt_a[j];
  }

  mydt=1.0e6f*mydt/((Double)(Layout::sitesOnNode())/Double(2));

  push(xml, "TimeCG2");
  write(xml, "iter", iter);
  write(xml, "N_dslash", N_dslash);
  write(xml, "N_mpsi", N_mpsi);
  write(xml, "Nflops", Nflops);
  write(xml, "mydt_a", mydt_a);
  write(xml, "mflops_a", mflops);
  pop(xml);
  
  pop(xml);
  xml.close();

  QDP_finalize();
    
  exit(0);
}
back to top