I detailed how to write a basic pseudo-spectral solver in 2D in Python in my previous post. In this one, I will talk about how to make it faster using [Numba] and [pyfftw].

As can be seen in my speed benchmark for matrix multiplication, numba is indeed very fast.

Using numba and pyfftw

Using numba and pyfftw allows us to

  • Combine the FFTs into what FFTW calls "many DFTs", using its advanced interface.
  • Use multithreading both for the FFTs and for the element-wise multiply/add operations on vectors and matrices.
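To illustrate what "many DFTs" means in practice, here is a minimal sketch using plain NumPy instead of pyfftw, so it is not the solver's actual code path. The sizes and the names `stack_k`, `batched` and `one_by_one` are hypothetical and chosen only for illustration. It shows that a single transform call over the last two axes of a stacked array gives the same result as transforming each field separately.

```python
import numpy as np

# Hypothetical small sizes, for illustration only
nx, ny = 8, 8
rng = np.random.default_rng(0)

# A stack of 4 half-complex spectra, shaped like the datk array in this post
shape_k = (4, nx, ny // 2 + 1)
stack_k = rng.standard_normal(shape_k) + 1j * rng.standard_normal(shape_k)

# One call transforms all 4 fields over the last two axes
batched = np.fft.irfft2(stack_k, s=(nx, ny), axes=(-2, -1))

# The same result, obtained with 4 separate 2D transforms
one_by_one = np.stack([np.fft.irfft2(stack_k[l], s=(nx, ny)) for l in range(4)])

print(batched.shape)                     # (4, 8, 8)
print(np.allclose(batched, one_by_one))  # True
```

The appeal of the batched form is that the library sees all the transforms at once, which is what makes a single multithreaded plan possible.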

I will discuss only the major differences to the previous formulation instead of going over everything again. You can check the full source code if you want to see how it all fits together.

In order to use numba and pyfftw we add the following lines to the imports

navsp2.py

from numba import njit,prange,set_num_threads
import pyfftw as pyfw

nthreads=16
set_num_threads(nthreads)
zeros=pyfw.empty_aligned

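Note that we use pyfw.empty_aligned instead of np.zeros as the initialization routine for all the arrays. This makes the arrays properly aligned for fast Fourier transforms. Keep in mind that empty_aligned returns uninitialized memory, so any array that is not completely overwritten before its first use has to be zeroed explicitly (pyfftw also provides zeros_aligned for this purpose).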

When initializing the vectors and matrices, we can add the following lines

datk=zeros((4,Npx,int(Npy/2+1)),dtype=complex)
dat=datk.view(dtype=float)[:,:,:-2]
rdatk=datk[0,:,:]
rdat=rdatk.view(dtype=float)[:,:-2]

#Initializing pyfftw plans
fftw_dat4b = pyfw.FFTW(datk, dat, axes=(-2, -1),direction='FFTW_BACKWARD',normalise_idft=True,threads=nthreads)
fftw_dat1f = pyfw.FFTW(rdat,rdatk,axes=(-2, -1),direction='FFTW_FORWARD',normalise_idft=True,threads=nthreads)

Using views of the same array, as above, allows us to perform many in-place c2r (complex-to-real) or r2c transforms at once. Note that an in-place transform overwrites its input array.
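The view trick can be checked without pyfftw. The following is a minimal sketch with plain NumPy arrays and hypothetical small sizes (`npx`, `npy`). It only shows that the real view shares memory with the complex array and has the expected shape once the two padding columns are dropped.

```python
import numpy as np

# Hypothetical padded sizes, for illustration only
npx, npy = 6, 8

# Complex storage for 4 half-complex spectra
datk = np.zeros((4, npx, npy // 2 + 1), dtype=complex)

# Real view of the same memory: the last axis doubles to npy + 2,
# and the last 2 columns are padding that the real data does not use
dat = datk.view(dtype=float)[:, :, :-2]

print(dat.shape)                    # (4, 6, 8)
print(np.shares_memory(dat, datk))  # True

# Writing through the real view changes the complex array
dat[0, 0, 0] = 3.0
dat[0, 0, 1] = 4.0
print(datk[0, 0, 0])                # (3+4j)
```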

We also need three separate functions that we will write in numba in order to accelerate the RHS function.

As discussed earlier, the workflow of the RHS function consists of:

  • Initialize the 4 components of the datk array as \(\big[i k_x \Phi _k, i k_y \Phi _k, -i k_x k^2\Phi _k, -i k_y k^2 \Phi _k\big]\) with padding.
  • Compute inverse Fourier transforms to obtain the dat array, as \(\big[\partial_x \Phi, \partial_y \Phi, \partial_x \nabla^2 \Phi, \partial_y \nabla^2\Phi\big]\).
  • Multiply the components of the dat array to form the convolution, stored in rdat, as \(\partial_x\Phi\partial_y\nabla^2\Phi-\partial_y\Phi\partial_x\nabla^2\Phi\).
  • Compute the forward Fourier transform of the rdat to get rdatk, which is the nonlinear term without the additional 1/ksqr.
  • Multiply and add everything together:
    • uk multiplied by Lnm, representing the linear terms in the equation,
    • rdatk multiplied by Nlm, representing the nonlinear terms,
    • Fnm, representing forcing.
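The first four steps above can be sketched in plain NumPy as follows. This is a simplified illustration and not the solver's code: it assumes a 2π-periodic box with integer wavenumbers, skips the padding entirely, and uses a hypothetical test field Φ = sin(x) + sin(2y), for which the nonlinear term can be worked out by hand as −6 cos(x) cos(2y).

```python
import numpy as np

n = 16  # hypothetical resolution, no padding in this sketch
x = np.arange(n) * 2 * np.pi / n
X, Y = np.meshgrid(x, x, indexing="ij")

# Integer wavenumbers for a half-complex (rfft2) layout
kx, ky = np.meshgrid(np.fft.fftfreq(n, 1 / n), np.fft.rfftfreq(n, 1 / n), indexing="ij")
ksqr = kx**2 + ky**2

phi = np.sin(X) + np.sin(2 * Y)
phik = np.fft.rfft2(phi)

# Step 1: the four spectra [i kx Phi, i ky Phi, -i kx k^2 Phi, -i ky k^2 Phi]
datk = np.stack([1j * kx * phik, 1j * ky * phik,
                 -1j * kx * ksqr * phik, -1j * ky * ksqr * phik])

# Step 2: four inverse transforms in one call
dat = np.fft.irfft2(datk, s=(n, n), axes=(-2, -1))

# Step 3: pointwise products forming the nonlinear term in real space
rdat = dat[0] * dat[3] - dat[1] * dat[2]

# Step 4: forward transform of the result
rdatk = np.fft.rfft2(rdat)

expected = -6 * np.cos(X) * np.cos(2 * Y)
print(np.allclose(rdat, expected))  # True
```

The product only contains modes up to (1, 2) here, so skipping the padding does no harm; for a general field the padded arrays are what keep aliasing out of the result.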

We can define a numba function that performs the first step as follows

#numba function that initializes the convolution arrays
@njit(fastmath=True,parallel=True)
def setdatk(v,d,kx,ky,ksqr):
    for i in prange(v.shape[0]):
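        #row in the padded array: same index for the first half of the rows, shifted to the end for the second half
        ip=i+int(2*i/v.shape[0])*(d.shape[1]-v.shape[0])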
        for j in prange(v.shape[1]):
            d[0,ip,j]=1j*kx[i,j]*v[i,j]
            d[1,ip,j]=1j*ky[i,j]*v[i,j]
            d[2,ip,j]=-1j*kx[i,j]*ksqr[i,j]*v[i,j]
            d[3,ip,j]=-1j*ky[i,j]*ksqr[i,j]*v[i,j]
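The row index `ip` is the least obvious part of this function, so here is a small sketch of what it does. The helper name `padded_row` and the sizes are hypothetical, and the wavenumber check assumes the standard FFT ordering (non-negative wavenumbers first, negative ones last), which may differ from the ordering used in the full source.

```python
import numpy as np

def padded_row(i, n, npad):
    """Row of the padded array that receives row i of the n-row spectrum."""
    return i + int(2 * i / n) * (npad - n)

n, npad = 8, 12  # hypothetical unpadded and padded sizes
rows = [padded_row(i, n, npad) for i in range(n)]
print(rows)  # [0, 1, 2, 3, 8, 9, 10, 11]

# With standard FFT ordering, each row lands on the same wavenumber
k_small = np.fft.fftfreq(n, 1 / n)
k_padded = np.fft.fftfreq(npad, 1 / npad)
print(np.allclose(k_padded[rows], k_small))  # True
```

The rows that are never written (4 to 7 in this example) are the zero padding in the middle of the spectrum, which is why the padded array has to be cleared before each call.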

The third step can be done using a convolution multiplier function, written in numba as

#numba convolution multiplier function
@njit(fastmath=True,parallel=True)
def multconv(d,rd):
    for i in prange(rd.shape[0]):
        for j in prange(rd.shape[1]):
            rd[i,j]=d[0,i,j]*d[3,i,j]-d[1,i,j]*d[2,i,j]

and finally the last step can be written in numba as

#numba matrix vector multiplication function
@njit(fastmath=True,parallel=True)
def mvecmult(v,a,b,c,d,res):
    for i in prange(v.shape[0]):
        #row of the padded array b that corresponds to row i of v
        ip=i+int(2*i/v.shape[0])*(b.shape[0]-v.shape[0])
        for j in prange(v.shape[1]):
            res[i,j]=v[i,j]*a[i,j]+b[ip,j]*c[i,j]+d[i,j]
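For testing the two element-wise kernels, it can help to have whole-array NumPy equivalents to compare against. The following is a sketch of such reference versions; the names `multconv_ref` and `mvecmult_ref` and all the sizes are hypothetical, and these are meant as correctness checks rather than fast replacements.

```python
import numpy as np

def multconv_ref(d, rd):
    """Whole-array form of multconv: rd = d0*d3 - d1*d2."""
    rd[...] = d[0] * d[3] - d[1] * d[2]

def mvecmult_ref(v, a, b, c, d, res):
    """Whole-array form of mvecmult, including the padded-row lookup."""
    n, m = v.shape
    i = np.arange(n)
    ip = i + (2 * i // n) * (b.shape[0] - n)  # same mapping as in the loop
    res[...] = v * a + b[ip, :m] * c + d

# Hypothetical sizes: 8x5 spectrum, 12x7 padded spectrum
rng = np.random.default_rng(1)
n, m, npad, mpad = 8, 5, 12, 7
v, a, c, d = (rng.standard_normal((n, m)) + 1j * rng.standard_normal((n, m))
              for _ in range(4))
b = rng.standard_normal((npad, mpad)) + 1j * rng.standard_normal((npad, mpad))
res = np.empty((n, m), dtype=complex)
mvecmult_ref(v, a, b, c, d, res)

# Spot check against the scalar formula: row 6 of v reads row 10 of b
print(np.isclose(res[6, 2], v[6, 2] * a[6, 2] + b[10, 2] * c[6, 2] + d[6, 2]))  # True

dat = rng.standard_normal((4, 6, 8))
rdat = np.empty((6, 8))
multconv_ref(dat, rdat)
print(np.isclose(rdat[1, 2], dat[0, 1, 2] * dat[3, 1, 2] - dat[1, 1, 2] * dat[2, 1, 2]))  # True
```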

Using these functions, we can formulate the right-hand side function in the form

#The ODE RHS function for 2D Navier-Stokes 
def f(t,y):
    vk=y.view(dtype=complex).reshape(uk.shape)
    datk.fill(0)
    setdatk(vk,datk,kx,ky,ksqr)
    fftw_dat4b()
    multconv(dat,rdat)
    fftw_dat1f()
    mvecmult(vk,Lnm,rdatk,Nlm,Fnm,dukdt)
    return dukdt.ravel().view(dtype=float)

which basically follows the workflow described above. Unlike the numpy version, this version of the solver is almost fully multithreaded. It takes slightly over 4 hours to run on a 16-core workstation at the same resolution as the pure Python/numpy case (which took about 2 days).
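One detail of the RHS function worth spelling out is how it moves between the flat real vector that an ODE solver works with and the complex 2D array that the kernels operate on. The sketch below uses a hypothetical tiny state array in place of the real uk, and shows that both conversions are views, so no data is copied.

```python
import numpy as np

# Hypothetical small complex state, standing in for uk
uk = np.arange(6, dtype=complex).reshape(2, 3) * (1 + 1j)

# Flat real vector, as handed to the ODE solver
y = uk.ravel().view(dtype=float)

# Complex 2D view of the same memory, as used inside f
vk = y.view(dtype=complex).reshape(uk.shape)

print(y.shape)                   # (12,)
print(np.array_equal(vk, uk))    # True
print(np.shares_memory(vk, uk))  # True
```

This relies on the arrays being contiguous; for a non-contiguous array, ravel would return a copy instead of a view.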