mpi4py-fft is a parallel FFT library for Python that uses serial FFTW (not fftw3_mpi directly) together with mpi4py to compute two- and three-dimensional parallel FFTs. It is well documented, so we do not need to introduce its basic usage here.

It is based on the idea that a 2D FFT can be computed as 1D FFTs along the y direction, followed by a transpose (with MPI alltoall), followed by 1D FFTs along the x direction. The result can then be transposed back, or left as it is if we can live with a transposed result.
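The transpose-based idea can be checked serially with NumPy alone. The sketch below is only a single-process illustration: the array transpose stands in for the MPI alltoall step, and the helper name fft2_by_transposes is hypothetical, not part of mpi4py-fft.

```python
import numpy as np

def fft2_by_transposes(a):
    """2D FFT of a[x, y] built from 1D FFTs and transposes (serial illustration)."""
    b = np.fft.fft(a, axis=1)   # 1D FFTs along y (the last axis)
    b = b.T                     # transpose: stands in for the MPI alltoall
    b = np.fft.fft(b, axis=1)   # 1D FFTs along x, which is now the last axis
    return b.T                  # transpose back (optional if a transposed result is fine)

rng = np.random.default_rng(0)
a = rng.standard_normal((8, 6))
result = fft2_by_transposes(a)

# Compare with NumPy's direct 2D transform
matches_fft2 = np.allclose(result, np.fft.fft2(a))
```

In the parallel case each process would hold a slab of the array, and only the transpose step would require communication.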

Basic Usage:

The basic use of the parallel FFT is described here. There is also a nice minimal example of a pseudo-spectral solver shown here.

However, we generally need to perform several 2D FFTs at once (e.g. first 6, then 2) on a 3D array. This could of course be done using lists of 2D arrays, but it also makes sense to do it using PFFT directly on a 3D array.

Multiple 2D PFFTs on a 3D array:

The basic interface can be used to perform multiple 2D transforms as follows:

test_pfft.py


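from mpi4py import MPI
from mpi4py_fft import PFFT,newDistArray
import numpy as np

howmany=6                          # number of independent 2D fields
Nx,Ny=128,128                      # unpadded (spectral) resolution
padx,pady=3/2,3/2                  # padding factors in x and y
Npx,Npy=int(Nx*padx),int(Ny*pady)  # padded (real space) resolution
comm=MPI.COMM_WORLD

# transform over axes 1 and 2 only; axis 0 just indexes the fields
pf=PFFT(comm,shape=(howmany,Nx,Ny),axes=(1,2), grid=[1,-1,1], padding=[1,padx,pady],collapse=False)
u=newDistArray(pf,forward_output=False)   # input of the forward transform (padded, real space)
uk=newDistArray(pf,forward_output=True)   # output of the forward transform (spectral)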

  • Here axes=(1,2) means that we transform over the last two indices (x and y).
  • grid=[1,-1,1] means that the input array (for the real-to-complex forward transform) is split among the processes in the x direction.
  • padding=[1,1.5,1.5] (that is, [1,padx,pady]) means that we pad by a factor of 3/2 in the x and y directions, following the 2/3 rule.
  • Note that, for some reason, collapse=True gives the error below. We therefore need to use the default choice collapse=False, which is written out explicitly here to warn the reader.

File "/usr/lib/python3.8/site-packages/mpi4py_fft/libfft.py", line 400, in __init__
    assert len(self.axes) == 1

  • The arrays u and uk in this example have the (global) shapes (6,192,192) and (6,128,65), since Npx=Npy=192 and Nyh=int(Ny/2+1)=65.
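The shape arithmetic can be reproduced without MPI. This is a sketch that uses NumPy's serial rfft2 as a stand-in for the real-to-complex transform; it only checks the padded size and the unpadded spectral shape, not the behaviour of PFFT itself.

```python
import numpy as np

howmany, Nx, Ny = 6, 128, 128
padx, pady = 3/2, 3/2

Npx, Npy = int(Nx*padx), int(Ny*pady)   # padded real space sizes
Nyh = int(Ny/2 + 1)                     # modes kept along the last (real-to-complex) axis

real_shape = (howmany, Npx, Npy)        # expected global shape of u
spectral_shape = (howmany, Nx, Nyh)     # expected global shape of uk

# Serial check: a real-to-complex transform over the last two axes
# leaves axis 0 alone and reduces the last axis to Ny/2+1 entries
spec = np.fft.rfft2(np.zeros((howmany, Nx, Ny)), axes=(1, 2))
```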

We can fill the input array and perform a forward transform. Note the caveat, however: the input array is supposed to be the padded, backward-transformed version of the output array, so it should contain no modes beyond those kept in uk. If we perform a forward transform followed by a backward transform on an arbitrary input, we will therefore likely get a smoothed version of the input array. If we use a relatively compact input array, this should not be an important issue.

n,x,y=np.meshgrid(np.arange(0,howmany),np.linspace(-1,1,Npx),np.linspace(-1,1,Npy),indexing='ij')
nl,xl,yl=n[u.local_slice()],x[u.local_slice()],y[u.local_slice()]
u[:]=np.sin(4*np.pi*(xl+2*yl))*np.exp(-xl**2/2/0.2-yl**2/2/0.1)*(nl-3)

Here is what this looks like in 2D when called with 4 processes:

[Figure: pfft_ex1_in, the input array u]

We can compute the Fourier transform using:

pf.forward(u,uk)

[Figure: pfft_ex1_in, the forward transform uk]

For some reason, empty pixels appear in the plot between the parts of the array owned by different processes. This is either an issue with how we compute the locations or an artefact of how matplotlib's pcolormesh handles the last points. In any case, it is most likely not real.

Note that, as a function of n, the array depends on the first index only through the linear factor (nl-3). We can use this to make sure that no FFT is computed along the first index: the same simple linear dependence on n holds both before and after the transform.
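A serial analogue of this check can be sketched with NumPy's rfft2 in place of PFFT (so it illustrates the argument, not the library). If the transform leaves the first index alone, each slice of the output must be the same 2D spectrum multiplied by n-3. The small grid sizes are chosen only for this illustration.

```python
import numpy as np

howmany, Npx, Npy = 6, 24, 24
n, x, y = np.meshgrid(np.arange(howmany), np.linspace(-1, 1, Npx),
                      np.linspace(-1, 1, Npy), indexing='ij')
u = np.sin(4*np.pi*(x + 2*y))*np.exp(-x**2/2/0.2 - y**2/2/0.1)*(n - 3)

uk = np.fft.rfft2(u, axes=(1, 2))   # transform over the last two axes only

# The slice n=4 carries the factor (4-3)=1, so it serves as the reference spectrum
factors = (np.arange(howmany) - 3)[:, None, None]
linear_in_n = np.allclose(uk, factors*uk[4])
```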

As discussed above, a forward transform followed by a backward transform does not return the original array exactly, because of the padding and because we start from an input that is supposed to be a padded array but in fact is not:

pf.forward(u,uk)
pf.backward(uk,u)

If we plot the difference from the original array (for index 0 along the first axis), we see this:

[Figure: pfft_ex1_in, difference between the round-trip result and the original array]

On the other hand, if we perform the forward and backward transforms once more and compare the result with the array obtained after the first round trip:

[Figure: pfft_ex1_in, difference between the second and the first round-trip results]

the difference becomes comparable to machine precision.
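Both observations can be reproduced in a serial 1D emulation. The sketch below assumes that the padded forward transform amounts to an FFT on the padded grid followed by truncation to the lowest modes, and that the backward transform zero-pads and inverts. The helper names forward and backward and the normalization are illustrative assumptions, not mpi4py-fft's implementation.

```python
import numpy as np

N, M = 8, 12   # spectral size and padded real space size (M = 3N/2)

def forward(u):
    """Padded grid -> truncated spectrum: keep only the N//2+1 lowest modes."""
    return np.fft.rfft(u)[:N//2 + 1]

def backward(uk):
    """Truncated spectrum -> padded grid: zero-pad the high modes, then invert."""
    ukpad = np.zeros(M//2 + 1, dtype=complex)
    ukpad[:N//2 + 1] = uk
    return np.fft.irfft(ukpad, n=M)

rng = np.random.default_rng(1)
u0 = rng.standard_normal(M)   # generic input with energy in the discarded modes
u1 = backward(forward(u0))    # first round trip: a smoothed version of u0
u2 = backward(forward(u1))    # second round trip

first_diff = np.max(np.abs(u1 - u0))    # not small: the high modes are lost
second_diff = np.max(np.abs(u2 - u1))   # tiny: u1 contains only retained modes
```

In this picture the round trip acts as a projection onto the retained modes, which is why applying it a second time changes nothing beyond rounding errors.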

Wavenumbers and FFT

It is clear from the figures above that mpi4py_fft computes non-centered FFTs, as is usually the case with FFT routines. The corresponding wavenumbers can be written as follows:

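# dkx and dky are the wavenumber spacings, assumed to be defined elsewhere
kx=np.r_[0:int(Nx/2),-int(Nx/2):0]*dkx
# equivalently, for dkx=1: kx = np.fft.fftfreq(Nx, 1./Nx).astype(int)
ky=np.r_[0:int(Ny/2+1)]*dky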

Alternatively, the following index array transforms a non-centered array into a centered one, or a centered array into a non-centered one:

inds=np.r_[int(Nx/2):Nx,0:int(Nx/2)]

So that, for example, with uk shaped as (n,Nx,Nyh), we have:

np.allclose(np.fft.fftshift(uk,1),uk[:,inds,:])
Out: True
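The two index conventions can also be checked against NumPy's own helpers. A small sketch, assuming unit wavenumber spacing (dkx=dky=1) and even Nx and Ny:

```python
import numpy as np

Nx, Ny = 8, 8
dkx = dky = 1.0   # assumed spacing, chosen only for this check

kx = np.r_[0:int(Nx/2), -int(Nx/2):0]*dkx
ky = np.r_[0:int(Ny/2 + 1)]*dky

# Same ordering as NumPy's frequency helpers
kx_matches = np.allclose(kx, np.fft.fftfreq(Nx, 1./Nx)*dkx)
ky_matches = np.allclose(ky, np.fft.rfftfreq(Ny, 1./Ny)*dky)

# The index array reorders kx into centered (monotonically increasing) form
inds = np.r_[int(Nx/2):Nx, 0:int(Nx/2)]
kx_centered = kx[inds]
shift_matches = np.allclose(kx_centered, np.fft.fftshift(kx))
```

For even Nx, applying inds twice restores the original ordering, which is why the same array works in both directions.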

If we want to remove the Nyquist modes:

uk[:,:,-1]=0
uk[:,int(Nx/2),:]=0
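To see that these two lines address the Nyquist modes, we can look up the wavenumbers at those indices. A serial sketch with NumPy, assuming even Nx and Ny and integer wavenumbers; the array of ones is only a placeholder for an actual spectrum.

```python
import numpy as np

n, Nx, Ny = 2, 8, 8
kx = np.fft.fftfreq(Nx, 1./Nx)    # 0,1,...,Nx/2-1,-Nx/2,...,-1
ky = np.fft.rfftfreq(Ny, 1./Ny)   # 0,1,...,Ny/2

kx_nyquist = kx[int(Nx/2)]   # wavenumber zeroed by uk[:,int(Nx/2),:]=0
ky_nyquist = ky[-1]          # wavenumber zeroed by uk[:,:,-1]=0

uk = np.ones((n, Nx, int(Ny/2 + 1)), dtype=complex)
uk[:, :, -1] = 0
uk[:, int(Nx/2), :] = 0

# The surviving modes are exactly those with |kx| < Nx/2 and ky < Ny/2
KX, KY = np.meshgrid(kx, ky, indexing='ij')
kept = (np.abs(KX) < Nx/2) & (KY < Ny/2)
mask_matches = np.array_equal(uk[0] != 0, kept)
```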

In order to see how the padding works, we can try:

from mpi4py import MPI
from mpi4py_fft import PFFT,newDistArray, DistArray
import numpy as np
import matplotlib.pylab as plt

comm=MPI.COMM_WORLD

n,Nx,Ny=2,8,8
padx,pady=1.5,1.5
fft=PFFT(comm,shape=(n,Nx,Ny),axes=(1,2), grid=[1,-1,1], padding=[1,padx,pady],collapse=False)
fft_nopad=PFFT(comm,shape=(n,int(Nx*padx),int(Ny*pady)),axes=(1,2), grid=[1,-1,1])

# uk holds the spectral coefficients; kx and ky are allocated but not used in this snippet
uk=DistArray((n,Nx,int(Ny/2+1)),subcomm=(1,0,1),dtype=complex,alignment=2)
kx=DistArray((Nx,int(Ny/2+1)),subcomm=(0,1),dtype=float,alignment=1)
ky=DistArray((Nx,int(Ny/2+1)),subcomm=(0,1),dtype=float,alignment=1)
inds=np.r_[int(Nx/2):Nx,0:int(Nx/2)]

uk[0,:,:]=np.array([[inds[l]+1j*(m+uk.local_slice()[2].start) for m in range(uk.shape[2])] for l in range(uk.shape[1]) ])
uk[1,:,:]=np.array([[2*inds[l]+1j*(m+uk.local_slice()[2].start+1) for m in range(uk.shape[2])] for l in range(uk.shape[1]) ])

uk[:,:,-1]=0
uk[:,int(Nx/2),:]=0

u=newDistArray(fft,forward_output=False)
ukpad=newDistArray(fft_nopad,forward_output=True)

fft.backward(uk,u)
fft_nopad.forward(u,ukpad)

If we plot this, we get the following (with and without the uk[:,:,-1]=0 and uk[:,int(Nx/2),:]=0 lines):

[Figure: pfft_ukpad]
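What to expect from this experiment can be sketched serially. The emulation below assumes that padding amounts to inserting zeros at the high wavenumbers (in the middle of the kx axis and at the end of the ky axis) and that the Nyquist modes have been removed. The helper name pad_spectrum and the normalization are hypothetical and not taken from mpi4py-fft.

```python
import numpy as np

Nx, Ny = 8, 8
Mx, My = 12, 12   # padded sizes, int(Nx*1.5) and int(Ny*1.5)

def pad_spectrum(uk):
    """Embed an (Nx, Ny/2+1) spectrum into an (Mx, My/2+1) one by zero-padding."""
    ukpad = np.zeros((Mx, My//2 + 1), dtype=complex)
    ukpad[:Nx//2, :Ny//2 + 1] = uk[:Nx//2]       # rows with kx >= 0
    ukpad[-(Nx//2):, :Ny//2 + 1] = uk[Nx//2:]    # rows with kx < 0
    return ukpad

rng = np.random.default_rng(2)
uk = np.fft.rfft2(rng.standard_normal((Nx, Ny)))   # spectrum of a real field
uk[:, -1] = 0        # remove the ky Nyquist mode
uk[Nx//2, :] = 0     # remove the kx Nyquist mode

ukpad = pad_spectrum(uk)
u = np.fft.irfft2(ukpad, s=(Mx, My))   # backward transform onto the padded grid
ukpad_back = np.fft.rfft2(u)           # unpadded forward transform at the padded size

# The padded-size spectrum is the original one with zeros in the new high modes
roundtrip_ok = np.allclose(ukpad_back, ukpad)
```

If the library behaves as assumed here, ukpad in the MPI example should show the original coefficients in the low-wavenumber corners and zeros elsewhere, up to the library's own normalization.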