/*
   bsort_pp.c - Parallel sorting algorithm based on bubblesort;
                instrumented for parallel profiling

   compile: mpicc -Wall -O -o bsort_pp bsort_pp.c
   run:     mpirun -np num_procs bsort_pp in_file out_file
*/

#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>


/* Exchange the entries v[i] and v[j] (i == j is harmless); helper used by
   bubblesort. Both indices must lie within the array. */
static inline /* this improves performance; Exercise: by how much? */
void swap(int * v, int i, int j)
{
  int * const a = v + i;
  int * const b = v + j;
  const int held = *a;
  *a = *b;
  *b = held;
}


/* Sort the n entries of v into ascending order with bubblesort.
   Each pass walks the unsorted prefix v[0..last] swapping adjacent
   out-of-order pairs, which moves the largest value of that prefix to
   position last; the prefix then shrinks by one. For n < 2 nothing is
   touched. */
void bubblesort(int * v, int n)
{
  int last = n - 1;   // v[last+1 .. n-1] already holds its final values
  while (last > 0) {
    int k = 0;
    while (k < last) {
      if (v[k] > v[k+1])
        swap(v, k, k+1);
      k++;
    }
    last--;
  }
}


/* Merge two ascending arrays v1 (n1 elements) and v2 (n2 elements) into a
   newly malloc'ed ascending array of n1 + n2 elements; the caller owns the
   result and must free() it. The inputs are not modified. When the two
   heads are equal, the element from v2 is taken first.
   An empty result (n1 + n2 == 0) is still a valid, freeable non-NULL pointer,
   so NULL is returned only on failure: a negative length or an allocation
   that cannot be satisfied. */
int * merge(const int * v1, int n1, const int * v2, int n2)
{
  int * result;
  size_t total;
  int i = 0;
  int j = 0;
  int k = 0;

  if (n1 < 0 || n2 < 0)
    return NULL;
  total = (size_t)n1 + (size_t)n2;
  // guard the byte-count multiplication against size_t overflow
  if (total > (size_t)-1 / sizeof *result)
    return NULL;
  // allocate at least one int so that NULL unambiguously means failure
  result = malloc((total ? total : 1) * sizeof *result);
  if (result == NULL)
    return NULL;

  // both inputs non-exhausted: take the smaller head (v2 on ties)
  while (i < n1 && j < n2)
    result[k++] = (v1[i] < v2[j]) ? v1[i++] : v2[j++];
  // at most one of these tails is non-empty; copy it over unchanged
  while (i < n1)
    result[k++] = v1[i++];
  while (j < n2)
    result[k++] = v2[j++];
  return result;
}


/* ceil(n / p) for n >= 0 and p > 0: the length of each of the p equal chunks
   the n input elements are scattered in (trailing chunks may be partly or
   entirely zero padding). */
static int chunk_len(int n, int p)
{
  return n / p + (n % p != 0);
}


/* Number of real (non-padding) elements held by the `count` consecutive ranks
   starting at rank `first`, when n elements are laid out c per rank.
   Clamped at 0: a rank whose chunk lies entirely past the n-th element owns
   nothing. Without the clamp the count goes negative, e.g. n=5, p=4 gives
   c=2 and rank 3 a size of 5 - 6 = -1, which would then be passed to
   MPI_Send/MPI_Recv/malloc. */
static int range_len(int n, int c, int first, int count)
{
  long long lo = (long long)c * first;
  long long hi = (long long)c * ((long long)first + count);
  if (hi > n)
    hi = n;
  return (hi > lo) ? (int)(hi - lo) : 0;
}


/* Allocate room for `count` ints (at least one, so that zero-length buffers
   are still valid non-NULL pointers). On failure, print a message and
   terminate this process. */
static int * alloc_ints(size_t count)
{
  int * v = NULL;
  if (count == 0)
    count = 1;
  if (count <= (size_t)-1 / sizeof *v)
    v = malloc(count * sizeof *v);
  if (v == NULL) {
    fprintf(stderr, "bsort_pp: cannot allocate %zu ints\n", count);
    // NOTE(review): this exits only the calling rank; MPI_Abort would tear
    // down the whole job explicitly.
    exit(EXIT_FAILURE);
  }
  return v;
}


/* Read "n x_0 ... x_{n-1}" (whitespace separated ints) from `path` into a new
   buffer of p * chunk_len(n, p) ints, zero-padded after x_{n-1}, and store n
   in *n_out. On any error (unopenable file, missing/negative count, too few
   elements) print a message, store -1 in *n_out and return NULL. */
static int * read_input(const char * path, int p, int * n_out)
{
  FILE * file;
  int * data = NULL;
  int n = -1;
  size_t total, k;
  int i;

  *n_out = -1;
  file = fopen(path, "r");
  if (file == NULL) {
    perror(path);
    return NULL;
  }
  // read size of data
  if (fscanf(file, "%d", &n) != 1 || n < 0) {
    fprintf(stderr, "%s: missing or invalid element count\n", path);
    goto fail;
  }
  // room for p full chunks, so MPI_Scatter never reads past the buffer
  total = (size_t)p * (size_t)chunk_len(n, p);
  data = alloc_ints(total);
  for (i = 0; i < n; i++) {
    if (fscanf(file, "%d", &data[i]) != 1) {
      fprintf(stderr, "%s: expected %d elements, read only %d\n", path, n, i);
      goto fail;
    }
  }
  fclose(file);
  // pad data with 0 -- padding lies past every rank's range_len, so its
  // value doesn't matter
  for (k = (size_t)n; k < total; k++)
    data[k] = 0;
  *n_out = n;
  return data;

fail:
  free(data);
  fclose(file);
  return NULL;
}


/* Write "len\n" followed by the len elements of v, one per line, to `path`.
   Returns 0 on success; prints a message and returns -1 on failure. */
static int write_output(const char * path, const int * v, int len)
{
  FILE * file = fopen(path, "w");
  int ok;
  int i;

  if (file == NULL) {
    perror(path);
    return -1;
  }
  ok = fprintf(file, "%d\n", len) >= 0;
  for (i = 0; ok && i < len; i++)
    ok = fprintf(file, "%d\n", v[i]) >= 0;
  // fclose flushes the stream; a failure here means output was lost
  if (fclose(file) != 0)
    ok = 0;
  if (!ok) {
    fprintf(stderr, "%s: error writing output\n", path);
    return -1;
  }
  return 0;
}


/* Parallel sort driver.
   1. Rank 0 reads the input file (element count followed by the elements).
   2. The data is scattered in chunks of c = ceil(n/p) elements.
   3. Every rank bubblesorts its own chunk.
   4. The sorted chunks are merged pairwise in a binary tree of
      ceil(log2 p) stages until rank 0 holds all n elements.
   5. Rank 0 writes the result to the output file.
   6. Every rank prints one line of timings.
   Returns EXIT_FAILURE if the input cannot be read or the output cannot be
   written. */
int main(int argc, char ** argv)
{
  int n = -1;
  int * data = NULL;
  int c, s;
  int * chunk;
  int o;
  int * other;
  int step;
  int p, id;
  MPI_Status status;
  int rc = EXIT_SUCCESS;
  // for profiling
  double t_total, t_bcast, t_scatter, t_sort, t_merge_comm, t_merge_comp, t0;
  int n_merge_stages;

  if (argc!=3) {
    fprintf(stderr, "Usage: mpirun -np <num_procs> %s <in_file> <out_file>\n", argv[0]);
    exit(1);
  }

  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &p);
  MPI_Comm_rank(MPI_COMM_WORLD, &id);

  // rank 0 reads and zero-pads the input; n == -1 afterwards signals an error
  if (id == 0)
    data = read_input(argv[1], p, &n);

  // start the total timer
  MPI_Barrier(MPI_COMM_WORLD);
  t_total = - MPI_Wtime();

  // broadcast size (and profile broadcast time)
  t_bcast = - MPI_Wtime();
  MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
  t_bcast += MPI_Wtime();

  // input could not be read: every rank now sees n < 0, so all stop together
  if (n < 0) {
    MPI_Finalize();
    return EXIT_FAILURE;
  }

  // compute chunk size
  c = chunk_len(n, p);

  // scatter data (and profile scatter time)
  chunk = alloc_ints((size_t)c);
  t_scatter = - MPI_Wtime();
  MPI_Scatter(data, c, MPI_INT, chunk, c, MPI_INT, 0, MPI_COMM_WORLD);
  t_scatter += MPI_Wtime();
  free(data);
  data = NULL;

  // compute size of own chunk (0 if it is all padding) and sort it
  // (and profile sorting time)
  s = range_len(n, c, id, 1);
  t_sort = - MPI_Wtime();
  bubblesort(chunk, s);
  t_sort += MPI_Wtime();

  // up to log_2 p merge steps
  t_merge_comm = 0;
  t_merge_comp = 0;
  n_merge_stages = 0;
  for (step = 1; step < p; step = 2*step) {
    if (id % (2*step)) {
      // id is no multiple of 2*step: send chunk to id-step and exit loop
      MPI_Send(chunk, s, MPI_INT, id-step, 0, MPI_COMM_WORLD);
      break;
    }
    // id is multiple of 2*step: merge in chunk from id+step (if it exists)
    if (id+step < p) {
      n_merge_stages++;
      // size of the chunk to be received: everything ranks
      // id+step .. id+2*step-1 have accumulated (never negative)
      o = range_len(n, c, id+step, step);
      // receive other chunk (and cumulatively profile communication time)
      other = alloc_ints((size_t)o);
      t0 = MPI_Wtime();
      MPI_Recv(other, o, MPI_INT, id+step, 0, MPI_COMM_WORLD, &status);
      t_merge_comm += MPI_Wtime() - t0;
      // merge (and cumulatively profile merge time)
      t0 = MPI_Wtime();
      data = merge(chunk, s, other, o);
      t_merge_comp += MPI_Wtime() - t0;
      // an empty merge may legitimately yield NULL from malloc(0)
      if (data == NULL && s + o > 0) {
        fprintf(stderr, "bsort_pp: out of memory while merging\n");
        exit(EXIT_FAILURE);
      }
      free(chunk);
      free(other);
      chunk = data;
      data = NULL;
      s = s + o;
    }
  }

  // stop the total timer
  t_total += MPI_Wtime();

  // write sorted data to out file (on rank 0, s == n here)
  if (id == 0 && write_output(argv[2], chunk, s) != 0)
    rc = EXIT_FAILURE;
  free(chunk);
  chunk = NULL;

  // print timings
  MPI_Barrier(MPI_COMM_WORLD);
  printf("# n     p id  time[s] bcast[s] scat[s] sort[s] mrg_n mrg_comm[s] mrg_comp[s]\n");
  printf("%6d %2d %02d %8.3f %6.3f %8.3f %8.3f %3d %10.3f %11.3f\n", n, p, id, t_total, t_bcast, t_scatter, t_sort, n_merge_stages, t_merge_comm, t_merge_comp);
  MPI_Finalize();
  return rc;
}
