Strassen Matrix Multiplication

Java Code

import java.io.*;
/**
 * Square integer matrix helpers and Strassen's recursive matrix multiplication.
 *
 * <p>Matrices are plain {@code int[][]} arrays; each method only touches the leading
 * {@code n x n} (or {@code size x size}) corner of its array arguments. All arithmetic is
 * done in {@code int}, so very large entries can overflow silently.
 */
class multiply {
    /** Capacity (rows and columns) of the default operand/result arrays below. */
    static final int MAX_ORDER = 16;

    /** Default storage for the first operand. */
    int a[][] = new int[MAX_ORDER][MAX_ORDER];
    /** Default storage for the second operand. */
    int b[][] = new int[MAX_ORDER][MAX_ORDER];
    /** Default storage for the product. */
    int c[][] = new int[MAX_ORDER][MAX_ORDER];

    /**
     * Single reader shared by every {@link #input} call. A BufferedReader reads ahead, so
     * creating a fresh one per element would let the first one swallow the lines meant
     * for later elements whenever standard input is piped.
     */
    private BufferedReader console;

    /**
     * Prompts for and reads an {@code n x n} matrix from standard input, one integer per line.
     *
     * @param a destination array, at least {@code n x n}
     * @param n matrix order
     * @throws EOFException if standard input ends before all elements have been read
     * @throws IOException on any other read failure
     * @throws NumberFormatException if a line does not hold an integer
     */
    void input(int[][] a, int n) throws IOException {
        if (console == null) {
            console = new BufferedReader(new InputStreamReader(System.in));
        }
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.println("Input the element" + i + "," + j);
                String line = console.readLine();
                if (line == null) {
                    throw new EOFException("Input ended before element " + i + "," + j);
                }
                a[i][j] = Integer.parseInt(line.trim());
            }
        }
    }

    /**
     * Prints the leading {@code n x n} corner of {@code a}, one row per line, tab-separated.
     *
     * @param a matrix to print
     * @param n matrix order
     */
    void display(int[][] a, int n) {
        for (int i = 0; i < n; i++) {
            System.out.println("");
            for (int j = 0; j < n; j++) {
                System.out.print("\t" + a[i][j]);
            }
        }
        System.out.println("\n");
    }

    /**
     * Element-wise sum {@code c = a + b} over the leading {@code n x n} corner.
     * {@code c} may be the same array as {@code a} or {@code b}.
     */
    void add(int[][] a, int[][] b, int[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                c[i][j] = a[i][j] + b[i][j];
            }
        }
    }

    /**
     * Element-wise difference {@code c = a - b} over the leading {@code n x n} corner.
     * {@code c} may be the same array as {@code a} or {@code b}.
     */
    void sub(int[][] a, int[][] b, int[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                c[i][j] = a[i][j] - b[i][j];
            }
        }
    }

    /**
     * Copies one {@code size/2 x size/2} quadrant of {@code a} into {@code b}.
     *
     * @param a    source matrix of order {@code size}
     * @param msb  0 selects the top half of the rows, 1 the bottom half
     * @param lsb  0 selects the left half of the columns, 1 the right half
     * @param b    destination, at least {@code size/2 x size/2}
     * @param size order of {@code a}
     */
    void submatrix(int[][] a, int msb, int lsb, int[][] b, int size) {
        int half = size / 2;
        for (int i = 0; i < half; i++) {
            for (int j = 0; j < half; j++) {
                b[i][j] = a[half * msb + i][half * lsb + j];
            }
        }
    }

    /**
     * Assembles {@code x} of order {@code size} from four quadrants:
     * {@code a} top-left, {@code b} top-right, {@code c} bottom-left, {@code d} bottom-right.
     */
    void join(int[][] x, int[][] a, int[][] b, int[][] c, int[][] d, int size) {
        int half = size / 2;
        for (int i = 0; i < half; i++) {
            for (int j = 0; j < half; j++) {
                x[i][j] = a[i][j];
            }
        }
        for (int i = 0; i < half; i++) {
            for (int j = half; j < size; j++) {
                x[i][j] = b[i][j - half];
            }
        }
        for (int i = half; i < size; i++) {
            for (int j = 0; j < half; j++) {
                x[i][j] = c[i - half][j];
            }
        }
        for (int i = half; i < size; i++) {
            for (int j = half; j < size; j++) {
                x[i][j] = d[i - half][j - half];
            }
        }
    }

    /**
     * Computes {@code z = x * y} for the leading {@code size x size} corners using
     * Strassen's algorithm.
     *
     * <p>Strassen's quadrant split needs a power-of-two order, so any other order is
     * zero-padded up to the next power of two. The padded product's top-left corner is
     * then the exact product. A {@code size <= 0} leaves {@code z} untouched.
     *
     * @param x    left operand
     * @param y    right operand
     * @param z    destination, at least {@code size x size}
     * @param size matrix order
     */
    void strasson(int[][] x, int[][] y, int[][] z, int size) {
        if (size <= 0) {
            return; // empty product: nothing to write
        }
        int padded = Integer.highestOneBit(size);
        if (padded == size) {
            multiplyPowerOfTwo(x, y, z, size);
            return;
        }
        padded <<= 1;
        int[][] xp = new int[padded][padded];
        int[][] yp = new int[padded][padded];
        int[][] zp = new int[padded][padded];
        for (int i = 0; i < size; i++) {
            System.arraycopy(x[i], 0, xp[i], 0, size);
            System.arraycopy(y[i], 0, yp[i], 0, size);
        }
        multiplyPowerOfTwo(xp, yp, zp, padded);
        for (int i = 0; i < size; i++) {
            System.arraycopy(zp[i], 0, z[i], 0, size);
        }
    }

    /** Strassen recursion proper; {@code size} must be a power of two (>= 1). */
    private void multiplyPowerOfTwo(int[][] x, int[][] y, int[][] z, int size) {
        if (size == 1) {
            z[0][0] = x[0][0] * y[0][0];
            return;
        }
        int half = size / 2;

        // x = [[x11, x12], [x21, x22]], y = [[y11, y12], [y21, y22]]
        int[][] x11 = new int[half][half];
        int[][] x12 = new int[half][half];
        int[][] x21 = new int[half][half];
        int[][] x22 = new int[half][half];
        int[][] y11 = new int[half][half];
        int[][] y12 = new int[half][half];
        int[][] y21 = new int[half][half];
        int[][] y22 = new int[half][half];
        submatrix(x, 0, 0, x11, size);
        submatrix(x, 0, 1, x12, size);
        submatrix(x, 1, 0, x21, size);
        submatrix(x, 1, 1, x22, size);
        submatrix(y, 0, 0, y11, size);
        submatrix(y, 0, 1, y12, size);
        submatrix(y, 1, 0, y21, size);
        submatrix(y, 1, 1, y22, size);

        // Scratch operands; each is fully consumed by the recursive call before reuse.
        int[][] t1 = new int[half][half];
        int[][] t2 = new int[half][half];
        int[][] p1 = new int[half][half];
        int[][] p2 = new int[half][half];
        int[][] p3 = new int[half][half];
        int[][] p4 = new int[half][half];
        int[][] p5 = new int[half][half];
        int[][] p6 = new int[half][half];
        int[][] p7 = new int[half][half];

        sub(y12, y22, t1, half);                 // p1 = x11 (y12 - y22)
        multiplyPowerOfTwo(x11, t1, p1, half);
        add(x11, x12, t1, half);                 // p2 = (x11 + x12) y22
        multiplyPowerOfTwo(t1, y22, p2, half);
        add(x21, x22, t1, half);                 // p3 = (x21 + x22) y11
        multiplyPowerOfTwo(t1, y11, p3, half);
        sub(y21, y11, t1, half);                 // p4 = x22 (y21 - y11)
        multiplyPowerOfTwo(x22, t1, p4, half);
        add(x11, x22, t1, half);                 // p5 = (x11 + x22)(y11 + y22)
        add(y11, y22, t2, half);
        multiplyPowerOfTwo(t1, t2, p5, half);
        sub(x12, x22, t1, half);                 // p6 = (x12 - x22)(y21 + y22)
        add(y21, y22, t2, half);
        multiplyPowerOfTwo(t1, t2, p6, half);
        sub(x11, x21, t1, half);                 // p7 = (x11 - x21)(y11 + y12)
        add(y11, y12, t2, half);
        multiplyPowerOfTwo(t1, t2, p7, half);

        int[][] c11 = new int[half][half];
        int[][] c12 = new int[half][half];
        int[][] c21 = new int[half][half];
        int[][] c22 = new int[half][half];
        add(p5, p4, c11, half);                  // c11 = p5 + p4 - p2 + p6
        sub(c11, p2, c11, half);
        add(c11, p6, c11, half);
        add(p1, p2, c12, half);                  // c12 = p1 + p2
        add(p3, p4, c21, half);                  // c21 = p3 + p4
        add(p5, p1, c22, half);                  // c22 = p5 + p1 - p3 - p7
        sub(c22, p3, c22, half);
        sub(c22, p7, c22, half);

        join(z, c11, c12, c21, c22, size);
    }
}
     /**
      * Console driver: reads a matrix order and two matrices, multiplies them with
      * {@code multiply.strasson} and prints the product.
      */
     class matrix {
         /** Largest supported order: multiply's operand/result arrays are 16 x 16. */
         private static final int MAX_ORDER = 16;

         public static void main(String args[]) throws IOException {
             System.out.println("Enter the order of matrix");
             int n = Integer.parseInt(readLineUnbuffered().trim());
             // Reject orders that would index outside multiply's fixed-size arrays.
             if (n < 1 || n > MAX_ORDER) {
                 System.err.println("Order must be between 1 and " + MAX_ORDER);
                 return;
             }
             multiply m = new multiply();
             m.input(m.a, n);
             m.display(m.a, n);
             m.input(m.b, n);
             m.display(m.b, n);
             m.strasson(m.a, m.b, m.c, n);
             System.out.println("The matrix after multiplication is:");
             m.display(m.c, n);
         }

         /**
          * Reads one line from {@code System.in} a byte at a time, without the trailing '\n'.
          *
          * <p>A BufferedReader here would read ahead past the first line. The buffered
          * characters would then be missing from the element reads that follow on
          * {@code System.in} when input is piped.
          *
          * @throws EOFException if standard input is already exhausted
          */
         private static String readLineUnbuffered() throws IOException {
             StringBuilder line = new StringBuilder();
             int ch;
             while ((ch = System.in.read()) != -1 && ch != '\n') {
                 line.append((char) ch);
             }
             if (ch == -1 && line.length() == 0) {
                 throw new EOFException("No matrix order given");
             }
             return line.toString();
         }
     }