//
// pit_MP_MPI.C
//

#include "pit_system.h"
#include  <fstream.h>
#include  <string.h>
#include  <stdio.h>
#include "mpi.h"

// Status object shared by every blocking MPI_Recv in this file.
// NOTE(review): this one is not static, so it has external linkage;
// confirm no other translation unit refers to it before making it local.
MPI_Status status;

// Rank of this process and size of MPI_COMM_WORLD.  Both are filled in by
// particle_system::MP_initialize and read by the free functions below.
static int local_proc_id = 0;
static int total_proc_count = 0;

// NOTE: it may be better to keep these values only here, local to this
// file, instead of also storing them in the particle_system class
// (as nprocessors and myprocid).

// Make the nbh black-hole particles identical on every process.
// Rank 0 gathers them from pb[] (at the indices in bhlocs[]) into bhp[].
// Then bhp[] is broadcast as raw bytes to all ranks.  Every rank must call
// this with the same nbh, and bhp must hold nbh particles on every rank.
void particle_system::MP_distribute_BH_data()
{
    const bool is_root = (myprocid == 0);
    if (is_root) {
	int k = 0;
	while (k < nbh) {
	    bhp[k] = pb[bhlocs[k]];
	    ++k;
	}
    }
    // Sent as MPI_BYTE, i.e. a bit copy of each particle object.
    const int nbytes = static_cast<int>(sizeof(particle) * nbh);
    MPI_Bcast(bhp, nbytes, MPI_BYTE, 0, MPI_COMM_WORLD);
}


// Broadcast the run parameters from rank 0 to every other rank, then
// print (on every rank) the values now held.
// Every rank must call this.  The broadcasts are issued in a fixed order:
// five reals (as MPI_DOUBLE), then three ints.
// NOTE(review): MPI_DOUBLE assumes `real` is double.
void MP_copyparams(real& eta, real& eps, real& dtout,
		   real& dt_snap, real& delta_t, int& nbh,
		   int& grape_nclusters,
		   int& grape_firstcluster )
{
    real* const real_params[] = { &eta, &eps, &dtout, &dt_snap, &delta_t };
    int*  const int_params[]  = { &nbh, &grape_nclusters, &grape_firstcluster };
    const int n_real = sizeof(real_params) / sizeof(real_params[0]);
    const int n_int  = sizeof(int_params) / sizeof(int_params[0]);

    for (int k = 0; k < n_real; k++)
	MPI_Bcast(real_params[k], 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
    for (int k = 0; k < n_int; k++)
	MPI_Bcast(int_params[k], 1, MPI_INT, 0, MPI_COMM_WORLD);

    PRC(eta);     PRC(eps);    PRC(dtout); PRC(nbh);
    PRC(grape_nclusters);
    PRL(grape_firstcluster);
}
    

void particle_system::MP_initialize(int argc,char *argv[])
{
    int  namelen;
    int myid, numprocs;
    char processor_name[MPI_MAX_PROCESSOR_NAME];
    MPI_Init(&argc,&argv);
    MPI_Comm_size(MPI_COMM_WORLD,&numprocs);
    MPI_Comm_rank(MPI_COMM_WORLD,&myid);
    MPI_Get_processor_name(processor_name,&namelen);
    cerr << "Initialize:Myid = " << myid
	 <<  " Myname = " << processor_name
	 << " Nprocs = " << numprocs <<endl;
    nprocessors = total_proc_count = numprocs;
    myprocid = local_proc_id = myid;
    MPI_Barrier(MPI_COMM_WORLD);
}

// Make every rank agree on the snapshot flag and, when it is set, give
// each rank its own snapshot file name.
//
// flag : rank 0's value is broadcast; on return every rank holds it.
// name : when flag is set, the first 254 bytes of rank 0's buffer are
//        broadcast.  Every rank then appends "MP<nprocs>-<rank>".  The
//        buffer must be at least 254 bytes on every rank.  The result is
//        truncated to 253 characters and is always NUL-terminated.
// Does nothing when only one process is running.  Otherwise every rank
// must call it.
void MP_convert_snap_name(bool& flag, char * name)
{
    const int name_bytes = 254;    // bytes broadcast = minimum size of name[]
    int flag_copy = flag;          // the flag travels as an MPI_INT
    if (total_proc_count == 1) return;
    MPI_Bcast(&flag_copy,1,MPI_INT,0,MPI_COMM_WORLD);
    MPI_Barrier(MPI_COMM_WORLD);
    flag = flag_copy;
    if (flag){
	MPI_Bcast(name,name_bytes,MPI_CHAR,0,MPI_COMM_WORLD);
	// Terminate even if rank 0's name used all broadcast bytes.
	name[name_bytes-1] = '\0';
	// Room for 253 name characters, "MP", '-', two ints and the NUL.
	char work[name_bytes + 64];
	snprintf(work, sizeof(work), "%sMP%d-%d",
		 name, total_proc_count, local_proc_id);
	// strncpy does not terminate on truncation, so terminate explicitly.
	strncpy(name, work, name_bytes-1);
	name[name_bytes-1] = '\0';
    }
}

// One record in the ring exchange of
// MP_add_acc_and_jerk_for_list_from_other_host.  It holds a particle's
// predicted position and velocity plus the acceleration, jerk and
// potential accumulated so far.
// NOTE(review): records are sent as 13 MPI_DOUBLEs each.  That assumes
// `vector` is three `real`s, `real` is double and there is no padding.
// Member order therefore matters; confirm against pit_system.h.
struct pvajp {
    vector pos;
    vector vel;
    vector acc;
    vector jerk;
    real pot;
};

void MP_add_acc_and_jerk_for_list_from_other_host(particle* pb,
						  int nbody,
						  int nbh,
						  node_time* nt, 
						  int  n_next,
						  real eps2,
						  int nprocs,
						  int myid)
{
    
    real sys_t = nt[0].next_time;
    //    if (myid == 0)
    //	cerr << "Enter other host" <<nprocs << " " <<myid
    //	     << " time= " << sys_t << endl;
    int i,j;
    if (nprocs == 1) return;
    static pvajp * psource = NULL;
    static pvajp * pdest = NULL;
    particle pcopy;
    int nsend, nreceive;

    if (psource == NULL){
	psource = new pvajp[nbody+100];
	pdest = new pvajp[nbody+100];
    }
    //
    // collect contributions from forces on other processors
    //
    pvajp *pvptr = psource;
    for(i = 0; i <n_next;i++){
	particle *bi = nt[i].pptr;
	pvptr->pos = bi->get_pred_pos();
	pvptr->vel = bi->get_pred_vel();
	pvptr->acc = bi->get_acc();
	pvptr->jerk = bi->get_jerk();
	pvptr->pot = bi->get_pot();
	pvptr ++;
    }
    nsend = n_next;
    
    for(int ip = 0; ip<nprocs; ip++){
	// first transfer nsend....
	//	PRL(nsend);


	if ((myid %2) == 0)
	    MPI_Send( &nsend, 1, MPI_INT, (myid+1)%nprocs,ip*10 , MPI_COMM_WORLD);
	MPI_Recv( &nreceive, 1, MPI_INT, (myid+nprocs-1)%nprocs,ip*10 ,
		  MPI_COMM_WORLD,&status);
	if (myid %2)
	    MPI_Send( &nsend, 1, MPI_INT, (myid+1)%nprocs,ip*10 , MPI_COMM_WORLD);
	//	PRL(nreceive);
	//	cerr << "Send and receive end "<<endl;
	if ((ip == nprocs - 1)&&(n_next != nreceive)){
	    // after one ring rotation, something went wrong...
	    cerr << "calculate on other host internal error ";
	    PRC(ip); PRC(n_next); PRL(nreceive);
	    MPI_Abort(MPI_COMM_WORLD, -1);
	}
	// then transfer particle data and perform force calculation
	//	cerr << "Data send/receive start "<<endl;

	if ((myid %2) == 0)
	    if (nsend > 0) MPI_Send(psource, 13*nsend, MPI_DOUBLE,
				    (myid+1)%nprocs,ip*10+1 , MPI_COMM_WORLD);
	if (nreceive > 0) MPI_Recv(pdest, 13*nreceive, MPI_DOUBLE,
				   (myid+nprocs-1)%nprocs,ip*10+1 ,
				   MPI_COMM_WORLD,&status);
	if ((myid %2) != 0)
	    if (nsend > 0) MPI_Send(psource, 13*nsend, MPI_DOUBLE,
				    (myid+1)%nprocs,ip*10+1 , MPI_COMM_WORLD);
	//	cerr << "Data send/receive end "<<endl;
	if ((nreceive > 0)&&(ip<nprocs-1)){
	    pvajp *pvptr = pdest;
	    for(i = 0; i <nreceive;i++){
		pcopy.set_pred_pos(pvptr->pos);
		pcopy.set_pred_vel(pvptr->vel);
		pcopy.set_acc(pvptr->acc);
		pcopy.set_jerk(pvptr->jerk);
		pcopy.set_pot(pvptr->pot);
		particle *bj = pb;
		if (nbh > 0 ) {
		    cerr << "sorry, MPI host-only version with BH is not ready"<<endl;
		    MPI_Abort(MPI_COMM_WORLD, -1);
		}
		for(j=0;j<nbody;j++,bj++){
		    real epstmp = eps2;
		    pcopy.accumulate_acc_and_jerk(bj,epstmp);
		}
		pvptr->acc= pcopy.get_acc();
		pvptr->jerk=pcopy.get_jerk();
		pvptr->pot=pcopy.get_pot();
		pvptr++;
	    }
	}
	if (ip==nprocs-1){
	    pvajp *pvptr = pdest;
	    for(i = 0; i <n_next;i++){
		particle *bi = nt[i].pptr;
		bi->set_acc(pvptr->acc);
		bi->set_jerk(pvptr->jerk);
		bi->set_pot(pvptr->pot);
		pvptr++;
	    }
	}else if (nreceive>0){
	    for(i = 0; i <nreceive;i++){
		*(psource+i) = *(pdest+i);
	    }
	}
	nsend = nreceive;
	//	MPI_Barrier(MPI_COMM_WORLD);
    }
}

// Sum the centre-of-mass terms pos, vel and mass over all processes.
// On return every rank holds the global sums.  Every rank must call this.
// NOTE(review): sent as MPI_DOUBLE, which assumes `real` is double.
void MP_collect_cmterms(vector& pos,vector& vel,real& mass)
{
    MPI_Barrier(MPI_COMM_WORLD);
    real source[7];
    real dest[7];
    for(int i = 0;i<3;i++){
	source[i]= pos[i];
	source[i+3]= vel[i];
    }
    source[6]=mass;
    // Allreduce, not Reduce: with MPI_Reduce only the root's dest is
    // filled, but every rank copies dest back below.
    MPI_Allreduce(source,dest,7, MPI_DOUBLE, MPI_SUM,MPI_COMM_WORLD);
    for(int i = 0;i<3;i++){
	pos[i] = dest[i];
	vel[i] = dest[i+3];
    }
    mass = dest[6];
    if (local_proc_id == 0) cerr << "Exit collect_cmterms\n";
}
// Sum the three energy terms e1, e2 and e3 over all processes.
// On return every rank holds the global sums.  Every rank must call this.
// NOTE(review): sent as MPI_DOUBLE, which assumes `real` is double.
void MP_collect_energies(real& e1,
			 real& e2,
			 real& e3)
{
    MPI_Barrier(MPI_COMM_WORLD);
    real source[3];
    real dest[3];
    source[0]=e1;
    source[1]=e2;
    source[2]=e3;
    // Allreduce, not Reduce: with MPI_Reduce only the root's dest is
    // filled, but every rank copies dest back below.
    MPI_Allreduce(source,dest,3, MPI_DOUBLE, MPI_SUM,MPI_COMM_WORLD);
    e1 = dest[0];
    e2 = dest[1];
    e3 = dest[2];
}
// Sum step_count over all processes into MP_step_count.  Every rank must
// call this.  MPI_Reduce's receive buffer is significant only at the root,
// so MP_step_count is meaningful on rank 0 only.
// NOTE(review): MPI_DOUBLE assumes step_count and MP_step_count are double.
void particle_system::MP_collect_integration_stats()
{
    MPI_Reduce(&step_count,&MP_step_count,1, MPI_DOUBLE,MPI_SUM,0,MPI_COMM_WORLD); 
}
// Print one line per process to s: its step count, CPU time and wall-clock
// time.  Non-root ranks send their three figures to rank 0 using tags
// rank*13, rank*13+1 and rank*13+2.  Rank 0 writes every line, in rank
// order, starting with its own.  Every rank must call this.
void particle_system::MP_print_integration_stats(ostream &s)
{
    real cpu = cpu_time();
    real etime = wall_time();

    if (myprocid == 0) {
	for (int rank = 0; rank < nprocessors; rank++) {
	    real rank_steps;
	    real rank_cpu;
	    real rank_wall;
	    if (rank != 0) {
		const int tag = rank * 13;
		MPI_Recv(&rank_steps, 1, MPI_DOUBLE, rank, tag,     MPI_COMM_WORLD, &status);
		MPI_Recv(&rank_cpu,   1, MPI_DOUBLE, rank, tag + 1, MPI_COMM_WORLD, &status);
		MPI_Recv(&rank_wall,  1, MPI_DOUBLE, rank, tag + 2, MPI_COMM_WORLD, &status);
	    } else {
		rank_steps = step_count;
		rank_cpu = cpu;
		rank_wall = etime;
	    }
	    s << "steps on proc " << rank << " = " << rank_steps
	      << " CPU time " << rank_cpu
	      << " Wallclock time " << rank_wall << endl;
	}
	return;
    }

    const int tag = myprocid * 13;
    MPI_Send(&step_count, 1, MPI_DOUBLE, 0, tag,     MPI_COMM_WORLD);
    MPI_Send(&cpu,        1, MPI_DOUBLE, 0, tag + 1, MPI_COMM_WORLD);
    MPI_Send(&etime,      1, MPI_DOUBLE, 0, tag + 2, MPI_COMM_WORLD);
}

void MP_adjust_current_time_and_nnext(int& n_next, real& ttmp)
{
    real global_tmin = ttmp;
    //    if (local_proc_id == 0) cerr << "Enter adjust time\n";
    MPI_Barrier(MPI_COMM_WORLD);
    //    if (local_proc_id == 0) cerr << "All time end\n";
    MPI_Allreduce(&ttmp, &global_tmin,1, MPI_DOUBLE, MPI_MIN,MPI_COMM_WORLD);
    //    if (local_proc_id == 0) cerr << "Allreduce end\n";
    if (ttmp > global_tmin){
	ttmp = global_tmin;
	n_next = 0;
    }
    //    if (local_proc_id == 0) cerr << "Exit adjust time\n";
}
// Return the largest localval over all ranks.  Every rank must call this,
// and every rank gets the same result.
int MP_intmax(int localval)
{
    int maximum = localval;
    MPI_Allreduce(&localval, &maximum, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD);
    return maximum;
}


// Block until every rank in MPI_COMM_WORLD has reached this call.
void  MP_sync()
{
    (void) MPI_Barrier(MPI_COMM_WORLD);
}

// Return the largest GRAPE error code over all ranks.  Every rank must call
// this, and every rank gets the same result.
// global_error starts from the local code, as in MP_intmax.  That keeps the
// result defined even if MPI_Allreduce returns an error instead of
// aborting, e.g. when an error handler other than MPI_ERRORS_ARE_FATAL is
// installed.
int MP_get_grape_error(int error)
{
    int global_error = error;
    MPI_Allreduce(&error, &global_error,1, MPI_INT, MPI_MAX,MPI_COMM_WORLD);
    return global_error;
}


// Return how many lower-ranked processes share this process's host name,
// as reported by MPI_Get_processor_name.  This is the process's index
// among the ranks on its machine.  Each rank in turn broadcasts its host
// name, so every rank must call this.  Each comparison is logged to cerr.
int MP_get_grape_id()
{
    char myname[MPI_MAX_PROCESSOR_NAME];
    char othername[MPI_MAX_PROCESSOR_NAME];
    int namelen;
    MPI_Get_processor_name(myname, &namelen);

    int lower_ranks_here = 0;
    for (int root = 0; root < total_proc_count; root++) {
	// Only the root's copy is used; the broadcast overwrites the others.
	strncpy(othername, myname, MPI_MAX_PROCESSOR_NAME);
	MPI_Bcast(othername, MPI_MAX_PROCESSOR_NAME, MPI_CHAR,
		  root, MPI_COMM_WORLD);
	cerr << "myid = " << local_proc_id << " " << root << " "
	     << myname << " " << othername << endl;
	const bool lower_rank_same_host =
	    (root < local_proc_id) && (strcmp(myname, othername) == 0);
	if (lower_rank_same_host) lower_ranks_here++;
    }
    return lower_ranks_here;
}

// Shut down MPI.  Every rank must call this; no MPI calls are allowed after it.
void MP_end()
{
    (void) MPI_Finalize();
}
