Matrix/vector multiplications arise in many numerical problems, and they can be implemented in different ways in Python. Here we compare numpy einsum, numpy multiply, numba and C implementations, and try to find out which is faster for a given configuration and how each scales with the number of nodes under MPI.
Matrix Vector Multiplication:
Consider the following operation
\[\begin{equation}\label{eq:mult}\tag{1} R_{ijk}=\sum_{\ell=0}^N a_{i\ell jk}v_{\ell jk}+b_{ijk} \end{equation}\]
which can be implemented in the compact notation of numpy’s einsum routine as:
    import numpy as np

    def mult1(v,a,b,res):
        res[:]=np.einsum('jikl,ikl->jkl',a,v)+b
The same can be implemented using direct multiplications as:
    def mult2(v,a,b,res):
        # v broadcasts against the last three axes of a; summing over axis 1 contracts the shared index
        res[:]=np.sum(np.multiply(a,v),1)+b
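As a quick sanity check, the two numpy versions can be compared on small random inputs. This is only a sketch: the shapes (2, 2, nx, ny) for a and (2, nx, ny) for v, b and res are inferred from the einsum subscripts, and the sizes, the seed and the helper `crand` are illustrative choices, not part of the benchmark.

```python
import numpy as np

def mult1(v, a, b, res):
    res[:] = np.einsum('jikl,ikl->jkl', a, v) + b

def mult2(v, a, b, res):
    # v is broadcast against the last three axes of a, then axis 1 is summed out
    res[:] = np.sum(np.multiply(a, v), 1) + b

rng = np.random.default_rng(0)
nx, ny = 4, 5  # small illustrative grid, not the benchmark size

def crand(*shape):
    # hypothetical helper: complex array with random real and imaginary parts
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

a = crand(2, 2, nx, ny)
v = crand(2, nx, ny)
b = crand(2, nx, ny)

res1 = np.empty_like(v)
res2 = np.empty_like(v)
mult1(v, a, b, res1)
mult2(v, a, b, res2)

# both routes should give the same result up to rounding
agree = np.allclose(res1, res2)
print("einsum and multiply agree:", agree)
```

Note that the multiply version builds a temporary of the same size as a before summing, whereas einsum contracts directly.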
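Under numba:
    from numba import njit

    @njit
    def mult3(v,a,b,res):
        # explicit loops; the sum over the contracted index (size 2) is written out by hand
        for j in range(v.shape[0]):
            for lx in range(v.shape[1]):
                for ly in range(v.shape[2]):
                    res[j,lx,ly]=a[j,0,lx,ly]*v[0,lx,ly]+a[j,1,lx,ly]*v[1,lx,ly]+b[j,lx,ly]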
and finally in C:
    #include <complex.h>

    /* u, b and res are flat arrays of length 2*N; a is a flat array of length 2*2*N */
    void multc(double complex *u, double complex *a, double complex *b, double complex *res, unsigned int N){
        unsigned int l,j;
        for (j=0;j<2;j++){
            for (l=0;l<N;l++){
                res[N*j+l]=a[N*2*j+N*0+l]*u[N*0+l]+a[N*2*j+N*1+l]*u[N*1+l]+b[N*j+l];
            }
        }
    }
which we compile using the following Makefile:
    CFLAGS=-O3 -fopenmp -g -Wall -ansi -DNDEBUG -fomit-frame-pointer \
           -fstrict-aliasing -ffast-math -msse2 -mfpmath=sse -march=native

    libmult.so: mult.o
    	gcc $(CFLAGS) -shared -o libmult.so mult.o

    mult.o: mult.c
    	gcc $(CFLAGS) -fpic -c mult.c
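The flat indexing used in the C routine can be checked against the einsum version without compiling anything. The sketch below mirrors the C loop in pure Python on flattened arrays; it assumes that the arrays are stored in C (row-major) order and that N is the number of grid points nx*ny held by the process. The name `multc_reference` is hypothetical and exists only for this check.

```python
import numpy as np

def multc_reference(u, a, b, res, n):
    """Pure-Python mirror of the C loop, operating on flat 1-D arrays."""
    for j in range(2):
        for l in range(n):
            res[n*j + l] = (a[n*2*j + n*0 + l] * u[n*0 + l]
                            + a[n*2*j + n*1 + l] * u[n*1 + l]
                            + b[n*j + l])

rng = np.random.default_rng(1)
nx, ny = 3, 4   # small illustrative grid
n = nx * ny     # assumed meaning of N in the C routine

a = rng.standard_normal((2, 2, nx, ny)) + 1j * rng.standard_normal((2, 2, nx, ny))
v = rng.standard_normal((2, nx, ny)) + 1j * rng.standard_normal((2, nx, ny))
b = rng.standard_normal((2, nx, ny)) + 1j * rng.standard_normal((2, nx, ny))

res_flat = np.empty(2 * n, dtype=complex)
# ravel() flattens in C order, so a[j, i, k, l] lands at index n*2*j + n*i + (k*ny + l)
multc_reference(v.ravel(), a.ravel(), b.ravel(), res_flat, n)

expected = np.einsum('jikl,ikl->jkl', a, v) + b
matches = np.allclose(res_flat.reshape(2, nx, ny), expected)
print("flat indexing matches einsum:", matches)
```

Because the two grid indices are merged into a single index l, the C kernel only needs the total number of points, not the grid shape.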
We then perform a set of MPI runs with an input matrix of global size (2, 1024, 2047) in order to measure running speeds, and get the following chart:
These are the average times, in seconds, that the root process takes to compute 10 such multiplications (always with the same matrices) using the different implementations, with error bars indicating the standard deviation among these 10 results.
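A serial sketch of this kind of measurement is given below. It is one possible reading of the procedure (each of 10 calls is timed separately, then the mean and standard deviation are taken), it leaves out MPI entirely, and the function `time_calls` and the reduced array size are illustrative assumptions rather than the actual benchmark script.

```python
import time
import numpy as np

def mult1(v, a, b, res):
    res[:] = np.einsum('jikl,ikl->jkl', a, v) + b

def time_calls(func, args, n_repeat=10):
    """Time n_repeat calls of func(*args); return (mean, std) in seconds."""
    times = np.empty(n_repeat)
    for i in range(n_repeat):
        t0 = time.perf_counter()
        func(*args)
        times[i] = time.perf_counter() - t0
    return times.mean(), times.std()

rng = np.random.default_rng(2)
nx, ny = 64, 65  # much smaller than the benchmark size, to keep the sketch quick

a = rng.standard_normal((2, 2, nx, ny)) + 1j * rng.standard_normal((2, 2, nx, ny))
v = rng.standard_normal((2, nx, ny)) + 1j * rng.standard_normal((2, nx, ny))
b = rng.standard_normal((2, nx, ny)) + 1j * rng.standard_normal((2, nx, ny))
res = np.empty_like(v)

mean_t, std_t = time_calls(mult1, (v, a, b, res))
print(f"einsum: {mean_t:.3e} s +/- {std_t:.3e} s")
```

For the numba version, a warm-up call before timing would keep the one-off compilation cost out of the statistics.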
Details of the benchmark can be found here.
Note that passing fastmath=True to numba generally gives a bit more speedup. This makes numba potentially faster than C, provided the kernel is written carefully.
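For reference, a minimal sketch of the numba kernel with this option switched on could look as follows. `fastmath=True` is a documented argument of numba's `njit` that relaxes strict IEEE 754 semantics; the name `mult3_fast`, the small test arrays and the import fallback (which only lets the sketch run where numba is not installed) are illustrative additions.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # fallback so the sketch still runs (slowly, uncompiled) without numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func

@njit(fastmath=True)
def mult3_fast(v, a, b, res):
    # same loops as mult3; fastmath lets the compiler reorder floating-point operations
    for j in range(v.shape[0]):
        for lx in range(v.shape[1]):
            for ly in range(v.shape[2]):
                res[j, lx, ly] = (a[j, 0, lx, ly] * v[0, lx, ly]
                                  + a[j, 1, lx, ly] * v[1, lx, ly]
                                  + b[j, lx, ly])

rng = np.random.default_rng(3)
nx, ny = 8, 9  # small illustrative grid

a = rng.standard_normal((2, 2, nx, ny)) + 1j * rng.standard_normal((2, 2, nx, ny))
v = rng.standard_normal((2, nx, ny)) + 1j * rng.standard_normal((2, nx, ny))
b = rng.standard_normal((2, nx, ny)) + 1j * rng.standard_normal((2, nx, ny))

res_fast = np.empty_like(v)
mult3_fast(v, a, b, res_fast)

# fastmath may change rounding slightly, so compare with a tolerance
expected = np.einsum('jikl,ikl->jkl', a, v) + b
close = np.allclose(res_fast, expected)
print("fastmath kernel matches einsum:", close)
```

Since fastmath permits reassociation of floating-point operations, results should be compared with a tolerance rather than bit for bit.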