https://github.com/linbox-team/fflas-ffpack
Raw File
Tip revision: a7801a65e9972b71558322e43812f5a7e08bbb4d authored by Clement Pernet on 14 November 2017, 16:52:10 UTC
fix parallel transpose
Tip revision: a7801a6
testeur_fgemm.C
/* -*- mode: C++; tab-width: 8; indent-tabs-mode: t; c-basic-offset: 8 -*- */
//--------------------------------------------------------------------------
//                        Test for the  fgemm winograd 
//                  
//--------------------------------------------------------------------------
// Clement Pernet
//-------------------------------------------------------------------------

//#define NEWWINO

#include <iostream>
#include <iomanip>
using namespace std;

#include "givaro/modular.h"
//#include "timer.h"
#include "fflas-ffpack/utils/fflas_io.h"
#include "fflas-ffpack/fflas/fflas.h"


#include "givaro/givintprime.h"
#include "givaro/modular.h"
#include "givaro/gfq.h"

using namespace FFPACK;
using namespace Givaro;

// Coefficient field used by the whole test. Exactly one alias may be active;
// the others are kept as ready-to-enable alternatives.
//using Field = ModularBalanced<float>;
//using Field = ModularBalanced<double>;
using Field = Givaro::Modular<double>;
//using Field = Givaro::Modular<float>;
//using Field = Givaro::Modular<int>; //-> bug with w>=1 (old dynamic peeling)
//using Field = Givaro::Modular<int32_t>;
//using Field = GFqDom<int32_t>;

int main(int argc, char** argv){
 FFLAS::Timer tim;
	IntPrimeDom IPD;
	Field::Element alpha, beta;
	long p;
	size_t M, K, N, Wino;
	bool keepon = true;
	Integer _p,tmp;
	cerr<<setprecision(10);
	size_t TMAXM = 100, TMAXK = 100, TMAXN = 100;
	size_t PRIMESIZE = 23;
	size_t WINOMAX = 8;
	
	if (argc > 1 )
		TMAXM = atoi(argv[1]);
	if (argc > 2 )
		PRIMESIZE = atoi(argv[2]);
	if (argc > 3 )
		WINOMAX = atoi(argv[3]);
	if (argc > 4 )
		TMAXK = atoi(argv[4]);
	else
        TMAXK = TMAXM;
	if (argc > 5 )
		TMAXN = atoi(argv[5]);
    else
        TMAXN = TMAXM;


	enum FFLAS::FFLAS_TRANSPOSE ta, tb;
	size_t lda,ldb;
	Field::Element * A;
	Field::Element * B;
	Field::Element * C, *Cbis, *Cter;
	
	while (keepon){
		srandom(_p);
		do{
			//		max = Integer::random(2);
			_p = random();//max % (2<<30);
			IPD.prevprime( tmp, (_p% (1<<PRIMESIZE)) );
			p =  tmp;
			
		}while( (p <= 2) );

		Field F( (size_t) p );
		Field::RandIter RValue( F );
		//NonzeroRandIter<Field> RnValue( F, RValue );
		
		
		do{
			M = (size_t)  random() % TMAXM;
			K = (size_t)  random() % TMAXK;
			N = (size_t)  random() % TMAXN;
			Wino = random() % WINOMAX;
		} while (!( (K>>Wino > 0) && (M>>Wino > 0) && (N>>Wino > 0) ));

		if (random()%2){
			ta = FFLAS::FflasTrans;
			lda = M;
		}
		else{
			ta = FFLAS::FflasNoTrans;
			lda = K;
		}
		if (random()%2){
			tb = FFLAS::FflasTrans;
			ldb = K;
		}
		else{
			tb = FFLAS::FflasNoTrans;
			ldb = N;
		}
		
		A = FFLAS::fflas_new<Field::Element>(M*K);
		B = FFLAS::fflas_new<Field::Element>(K*N);
		C = FFLAS::fflas_new<Field::Element>(M*N);
		Cbis = FFLAS::fflas_new<Field::Element>(M*N);
		Cter = FFLAS::fflas_new<Field::Element>(M*N);
		
		for( size_t i = 0; i < M*K; ++i )
			RValue.random( *(A+i) );
		for( size_t i = 0; i < K*N; ++i )
			RValue.random( *(B+i) );
		for( size_t i = 0; i < M*N; ++i )
			*(Cter+i) = *(Cbis+i)= RValue.random( *(C+i) );
		
		RValue.random( alpha );
		RValue.random( beta );
		
		cout <<"p = "<<(size_t)p<<" M = "<<M
		     <<" K = "<<K
             <<" N = "<<N
             <<" Winolevel = "<<Wino<<" "
		     <<alpha
		     <<((ta==FFLAS::FflasNoTrans)?".Ax":".A^Tx")
		     <<((tb==FFLAS::FflasNoTrans)?"B + ":"B^T + ")
		     <<beta<<".C"
		     <<"....";

		tim.clear();
		tim.start();
		FFLAS::MMHelper<Field, FFLAS::MMHelperAlgo::Winograd> WH (F,Wino,FFLAS::ParSeqHelper::Sequential());
		FFLAS::fgemm (F, ta, tb, M, N, K, alpha, A, lda, B, ldb, beta, C, N, WH);
		tim.stop();
		
// 		for (int j = 0; j < n; ++j ){
// 			FFLAS::fgemv( F, FFLAS::FflasNoTrans, m, k, alpha, A, k, B+j, n, beta, Cbis+j, n);
// 			for (int i=0; i<m; ++i)
// 				if ( !F.areEqual( *(Cbis+i*n+j), *(C+i*n+j) ) ) 
// 					keepon = false;
// 		}
		Field::Element aij, bij, temp;
		//F.div(boa, beta, alpha);
		for (size_t i = 0; i < M; ++i )
			for ( size_t j = 0; j < N; ++j ){
				//				F.mulin(*(Cbis+i*N+j),boa);
				F.mulin(*(Cbis+i*N+j),beta);
				F.assign(temp,F.zero);
				for ( size_t l = 0; l < K ; ++l ){
					if ( ta == FFLAS::FflasNoTrans )
						aij = *(A+i*lda+l);
					else
						aij = *(A+l*lda+i);
					if ( tb == FFLAS::FflasNoTrans )
						bij = *(B+l*ldb+j);
					else
						bij = *(B+j*ldb+l);
					F.axpyin(temp,aij,bij);
					//F.axpyin( *(Cbis+i*N+j), aij, bij );
				}
				F.axpyin( *(Cbis+i*N+j), alpha, temp);
				//F.mulin( *(Cbis+i*N+j),alpha );
				if ( !F.areEqual( *(Cbis+i*N+j), *(C+i*N+j) ) ) {
					cerr<<"error for i,j="<<i<<" "<<j<<" "<<*(C+i*N+j)<<" "<<*(Cbis+i*N+j)<<" diff = "<< *(C+i*N+j)-*(Cbis+i*N+j) <<endl;
					keepon = false;
				}
			}
		
		if (keepon){
			cout<<"Passed "
			    <<(2*M*N/1000.0*K/tim.usertime()/1000.0)<<"Mfops"<<endl; 
			FFLAS::fflas_delete( A);
			FFLAS::fflas_delete( B);
			FFLAS::fflas_delete( C);
			FFLAS::fflas_delete( Cbis);
			FFLAS::fflas_delete( Cter);
		}
		else{
			// cerr<<"C="<<endl;
// 			FFLAS::WriteMatrix (cerr, F, M, N, C, N );
// 			cerr<<"Cbis="<<endl;
// 			FFLAS::WriteMatrix (cerr, F, M, N, Cbis, N );
		}
	}
	cout<<endl;
	cerr<<"FAILED with p = "<<(size_t)p<<" M = "<<M<<" N = "<<N<<" K = "<<K
	    <<" Winolevel = "<<Wino
	    <<" alpha = "<<(int)alpha<<" beta = "<<(int)beta<<endl; 

	if (M < 100 && N < 100) {
		cerr << "error locations (X)" << endl;
		Field F( (size_t) p );
		for (size_t i = 0; i < M; ++i ) {
			for ( size_t j = 0; j < N; ++j ){
				if ( !F.areEqual( *(Cbis+i*N+j), *(C+i*N+j) ) ) {
					cerr<<"x" ;
				}
				else
					cerr<<"." ;
			}
			cerr << endl;
		}
		cerr << endl;

	}

	cerr<<"A:"<<endl;
	if ( ta ==FFLAS::FflasNoTrans ){
		cerr<<M<<" "<<K<<" M"<<endl;
		for (size_t i=0; i<M; ++i)
			for (size_t j=0; j<K; ++j)
				cerr<<i+1<<" "<<j+1<<" "<<((int) *(A+i*lda+j) )<<endl;
	}
	else{
		cerr<<K<<" "<<M<<" M"<<endl;
		for (size_t i=0; i<K; ++i)
			for (size_t j=0; j<M; ++j)
				cerr<<i+1<<" "<<j+1<<" "<<((int) *(A+j*lda+i) )<<endl;

	}
	cerr<<"0 0 0"<<endl<<endl;
	cerr<<"B:"<<endl;
	if ( tb ==FFLAS::FflasNoTrans ){
		cerr<<K<<" "<<N<<" M"<<endl;
		for (size_t i=0; i<K; ++i)
			for (size_t j=0; j<N; ++j)
				cerr<<i+1<<" "<<j+1<<" "<<((int) *(B+i*ldb+j) )<<endl;
	}
	else{
		cerr<<N<<" "<<K<<" M"<<endl;
		for (size_t i=0; i<N; ++i)
			for (size_t j=0; j<K; ++j)
				cerr<<i+1<<" "<<j+1<<" "<<((int) *(B+i+j*ldb) )<<endl;
	}
	cerr<<"0 0 0"<<endl<<endl;
	cerr<<"C:"<<endl
	    <<M<<" "<<N<<" M"<<endl;
	for (size_t i=0; i<M; ++i)
		for (size_t j=0; j<N; ++j)
			cerr<<i+1<<" "<<j+1<<" "<<((int) *(Cter+i*N+j) )<<endl;
	cerr<<"0 0 0"<<endl;

	FFLAS::fflas_delete( A);
	FFLAS::fflas_delete( B);
	FFLAS::fflas_delete( C);
	FFLAS::fflas_delete( Cbis);
	FFLAS::fflas_delete( Cter);
}














back to top