# psimeq.py   solve simultaneous equations A*X = Y
# solve real linear equations for X where Y = A * X
# method: delegates to numpy.linalg.solve (LAPACK LU factorization with
#         partial pivoting); the original Gauss-Jordan elimination using
#         maximum pivot survives only as the inert reference string below
# usage:  X = psimeq(A,Y)
#    Translated to Java by Jon Squire, 26 March 2003
#    First written by Jon Squire December 1959 for IBM 650, translated to
#    other languages, e.g. Fortran converted to Ada converted to C,
#    then converted to Java and Python, then made parallel November 2008

from numpy import array
from numpy.linalg import solve

def psimeq(A, Y):
    """Solve the real linear system A @ X = Y for X.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Square coefficient matrix.
    Y : array_like, shape (n,) or (n, k)
        Right-hand side; a 2-D Y solves for k right-hand sides at once.

    Returns
    -------
    numpy.ndarray
        Solution X, shaped like Y.

    Raises
    ------
    numpy.linalg.LinAlgError
        If A is singular or not square.
    """
    # numpy.linalg.solve does all the work (LAPACK LU with partial pivoting).
    # Unlike the historical Gauss-Jordan code kept in this file, a singular A
    # raises LinAlgError instead of silently zeroing the redundant row.
    return solve(A, Y)

"""
  final int NP=4 # number of processors, modify to suit
  int n          # number of equations
  int m          # number of equations plus 1
  double B[][]   # working array, various processors on various parts
  int prow[]     # pivot row if not -1, variable number
  int frow[]     # finished row if not 0
  int k_col=0    # column being reduced, initial zero for "start"
  int k_row=0    # pivot row
  int srow[] = new int[NP+1] # starting row for each thread
  double smax_pivot[] = new double[NP] # max pivot each thread
  int spivot[] = new int[NP] # row for max pivot each thread
  CyclicBarrier barrier
  MyThread sthread[] = new MyThread[NP] # allocate all threads
 
  psimeq(final double A[][], final double Y[], double X[])
  {
    n=A.length
    m=n+1
    B=new double[n][m]  # working matrix
    if(A[0].length!=n || Y.length!=n || X.length!=n)
    {
      System.out.println("Error in Matrix.solve, inconsistent array sizes.")
    }
    solve(A, Y, X)
  }

  psimeq(int nn, final double A[][], final double Y[], double X[])
  {
    n=nn
    m=n+1
    B=new double[n][m]  # working matrix
    solve(A, Y, X)
  }

  psimeq(int nn, final double AA[], final double Y[], double X[])
  {
    n=nn
    m=n+1
    double A[][]=new double[n][n]  # reformat
    for(int i=0 i<n i++)
      for(int j=0 j<n j++)
	A[i][j] = AA[i*n+j]
    B=new double[n][m]  # working matrix
    solve(A, Y, X)
  }

  void solve(final double A[][], final double Y[], double X[])
  {
    int hold , I_pivot     # pivot indicies
    double pivot           # pivot element value
    int part=n/NP          # number of rows each thread processes

    # set up row range for each thread
    prow=new int[n]     # pivot row
    frow=new int[n]     # finished row
    srow[0]=0
    for(int i=1 i<NP i++)
    {
      srow[i]=srow[i-1]+part
    }
    srow[NP]=n # may be less than  part 

    # set up pivot row and finished row
    for(int k=0 k<n k++)
    {
      prow[k] = -1
      frow[k] = 0
    }

    for(int i=0 i<NP i++) sthread[i] = new MyThread(i,A,Y) # construct

    # threads build working data structures
    try
    {
      barrier = new CyclicBarrier(NP+1) # st
      for(int i=0 i<NP i++) sthread[i].start()
      barrier.await() # st
      for(k_col=0 k_col<n k_col++)
      {
        barrier = new CyclicBarrier(NP+1) # p1
        barrier.await() # p1
        update_row() 

        if(k_col>=n-1) break
        barrier = new CyclicBarrier(NP+1) # p2
        barrier.await() # p2
      } # finished solve
    }
    catch(InterruptedException e)
    {
      System.out.println("InterruptedException in master")
    }
    catch(BrokenBarrierException e)
    {
      System.out.println("BrokenBarrierException in master")
    }

    try
    {
      for(int i=0 i<NP i++) sthread[i].join()
    }
    catch(InterruptedException e)
    {
      System.out.println("InterruptedException in master")
    }

    #  build  X  for return, unscrambling rows
    for(int i=0 i<n i++)
    {
      X[i] = B[prow[i]][n]
    }
  } # end solve constructor, solution computed
  
  class MyThread extends Thread
  {
    int myid
    
    MyThread(int id, double A[][], double Y[])
    {
      myid = id
      # build working data structure
      for(int i=srow[myid] i<srow[myid+1] i++)
      {
        for(int j=0 j<n j++)
        {
          B[i][j] = A[i][j]
        }
        B[i][n] = Y[i]
      }
    } # end MyThread constructor
    
    public void run()
    {
      double pivot
      int rowi

      try
      {
        barrier.await() # st
        sleep(1)
	while(true)
	{
	  # find max pivot in my range
          # set smax_pivot[myid]
          # set spivot[myid]
          spivot[myid] = srow[myid]
          smax_pivot[myid] = 0.0 # singular caught elsewhere
          for(int i=srow[myid] i<srow[myid+1] i++)
          {
            if(frow[i]==0 && Math.abs(B[i][k_col]) > smax_pivot[myid])
            {
              spivot[myid] = i
              pivot = B[i][k_col]
              smax_pivot[myid] = Math.abs(pivot)
            }
          }

          barrier.await() # p1
          sleep(10)
          
          # reduce local rows
          # from   srow[myid] to  < srow[myid+1] 
          # using  k_row, k_col
 
          #  inner reduction loop
          for(int i=srow[myid] i<srow[myid+1] i++)
          {
	    if( i != k_row)
            {
              for(int j=k_col+1 j<n+1 j++)
              {
                B[i][j] = B[i][j] - B[i][k_col] *
                                    B[k_row][j]
              }
            }
          }
          #  finished inner reduction
          if(k_col>=n-1) break # this task may terminate
          barrier.await() # p2
          sleep(10)

	} # do next column
      } # end try
      catch(InterruptedException e)
      {
	System.out.println(myid+" InterruptedException")
      }
      catch(BrokenBarrierException e)
      {
	System.out.println(myid+" BrokenBarrierException")
      }
    } # end Run
  } # end class MyThread

  void update_row() # unify max pivots, update row[], 
  {
    int I_pivot, hold
    double abs_pivot

    # find max of threads
    abs_pivot = smax_pivot[0]
    I_pivot = spivot[0]
    for(int i=1 i<NP i++)
    {
      if(smax_pivot[i]>abs_pivot)
      {
	I_pivot = spivot[i]
	abs_pivot = smax_pivot[i]
      }
    }

    # have pivot, interchange row indicies
    k_row = I_pivot
    prow[k_col] = k_row
    frow[k_row] = 1

    # check for near singular
    if(abs_pivot < 1.0E-10)
    {
      for(int j=k_col+1 j<n+1 j++)
      {
        B[k_row][j] = 0.0
      }
      System.out.println("redundant row (singular) "+k_row)
    } # singular, set row to zero, including solution
    else
    {
      # reduce about pivot
      for(int j=k_col+1 j<n+1 j++)
      {
	B[k_row][j] = B[k_row][j] / B[k_row][k_col]
      }
    }
  } # end update_row

} # end class psimeq
"""