mpi4py-fft is a parallel FFT library for Python. It uses serial FFTW (that is, not fftw3_mpi directly) together with mpi4py to compute two- and three-dimensional parallel FFTs. It is well documented, so we do not need to introduce its general usage here.
It is based on the idea that a 2D FFT can be computed as 1D FFTs in the y direction, followed by a transpose (with MPI alltoall), followed by 1D FFTs in the x direction. We can then transpose the result back, or skip this last step if we can live with a transposed result.
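The decomposition described above can be checked serially with plain NumPy. The following is only a sketch of the idea, not how mpi4py-fft is implemented internally: the function name rfft2_by_transposes is made up for this illustration, and the transpose is an ordinary in-memory transpose standing in for the MPI alltoall exchange.

```python
import numpy as np

def rfft2_by_transposes(a, transpose_back=True):
    """2D real-to-complex FFT of a[x, y] built from 1D FFTs and transposes."""
    # Step 1: 1D real-to-complex transforms along y (the last axis).
    ak = np.fft.rfft(a, axis=1)
    # Step 2: transpose so that x becomes the last axis. In the parallel
    # case this is the step that requires the MPI alltoall exchange.
    ak = ak.T
    # Step 3: 1D complex transforms along x.
    ak = np.fft.fft(ak, axis=1)
    # Optional step 4: transpose back to the [kx, ky] layout.
    return ak.T if transpose_back else ak

rng = np.random.default_rng(0)
a = rng.standard_normal((8, 6))

full = rfft2_by_transposes(a)                              # indexed [kx, ky]
transposed = rfft2_by_transposes(a, transpose_back=False)  # indexed [ky, kx]

print(np.allclose(full, np.fft.rfft2(a)))
print(full.shape, transposed.shape)
```

Skipping the final transpose saves one global exchange, at the price of working with an array indexed as [ky, kx].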
Basic Usage:
The basic use of the parallel FFT is described here. There is also a nice minimal example of a pseudo-spectral solver shown here.
However, we generally need to perform multiple (e.g. first 6, then 2) 2D FFTs on a 3D array. This could of course be done using lists of 2D arrays, but it also makes sense to do it by applying PFFT directly to a 3D array.
Multiple 2D PFFTs on a 3D array:
The basic interface can be used to perform multiple 2D transforms as follows:
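from mpi4py import MPI
from mpi4py_fft import PFFT, newDistArray
import numpy as np
howmany = 6                            # number of independent 2D fields
Nx, Ny = 128, 128                      # unpadded resolution
padx, pady = 3/2, 3/2                  # padding factors in x and y
Npx, Npy = int(Nx*padx), int(Ny*pady)  # padded resolution
comm = MPI.COMM_WORLD
pf = PFFT(comm, shape=(howmany, Nx, Ny), axes=(1, 2), grid=[1, -1, 1], padding=[1, 1.5, 1.5], collapse=False)
u = newDistArray(pf, forward_output=False)   # input of the forward transform (padded, real space)
uk = newDistArray(pf, forward_output=True)   # output of the forward transform (Fourier space)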
- Here axes=(1,2) means that we transform over the last two indices (x and y).
- grid=[1,-1,1] means that the input array (of the real-to-complex forward transform) is split among the processes in the x direction.
- padding=[1,1.5,1.5] means that we pad the x and y directions using the 2/3 rule.
- Note that, for some reason, collapse=True gives the error below. So we need to use the default choice collapse=False, which is written out explicitly here to warn the reader.
File "/usr/lib/python3.8/site-packages/mpi4py_fft/libfft.py", line 400, in __init__
    assert len(self.axes) == 1
- The arrays u and uk in this example have the (global) shapes (6,192,192) and (6,128,65), respectively. Note that Npx=192 and Nyh=int(Ny/2+1)=65.
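The global shapes quoted in the last bullet follow from simple arithmetic. The helper below reproduces it; expected_shapes is a hypothetical name used only for this illustration and is not part of mpi4py-fft. It assumes that padding multiplies the real-space size along each padded axis and that the real-to-complex transform keeps Ny/2+1 modes along the last axis.

```python
def expected_shapes(howmany, Nx, Ny, padx, pady):
    """Global shapes of the padded real-space array and of the spectral array."""
    # Real space: the padded axes are enlarged by the padding factors.
    real_shape = (howmany, int(Nx * padx), int(Ny * pady))
    # Fourier space: unpadded in x, half-complex (Ny/2 + 1 modes) in y.
    spectral_shape = (howmany, Nx, Ny // 2 + 1)
    return real_shape, spectral_shape

real_shape, spectral_shape = expected_shapes(6, 128, 128, 1.5, 1.5)
print(real_shape, spectral_shape)
```

With several processes, the local shapes of u and uk are smaller than these global shapes along the distributed axis.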
We can fill the input array and perform a forward transform. Note the caveat, however, that the input array is actually supposed to be the padded, backward-transformed version of the output array. So if we perform a forward transform followed by a backward transform, we will most likely get a smoothed version of the input array. If we use a relatively compact input array, this should not be an important issue.
n,x,y=np.meshgrid(np.arange(0,howmany),np.linspace(-1,1,Npx),np.linspace(-1,1,Npy),indexing='ij')
nl,xl,yl=n[u.local_slice()],x[u.local_slice()],y[u.local_slice()]
u[:]=np.sin(4*np.pi*(xl+2*yl))*np.exp(-xl**2/2/0.2-yl**2/2/0.1)*(nl-3)
Here is what this looks like in 2D when called with 4 processes:
We can compute the Fourier transform using
pf.forward(u,uk)
For some reason, empty pixels appear in the plot between the parts of the array owned by different processes. This is either an issue with how we compute the locations or an artefact of how matplotlib's pcolormesh handles the last points. In any case, it is most likely not real.
Note that, as a function of n, the array simply increases or decreases linearly (through the factor nl-3). We can use this to make sure that no FFT is computed along the first index: both before and after the transform, we indeed have the same simple linear form.
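The same check can be reproduced serially. The sketch below uses np.fft.rfft2 over axes (1,2) as a stand-in for the parallel forward transform (so there is no padding and no MPI, and the normalization may differ from mpi4py-fft). It uses a smaller grid than in the example above and verifies that each slice of the spectrum is (n-3) times the spectrum of one reference slice.

```python
import numpy as np

howmany, Nx, Ny = 6, 16, 16
n, x, y = np.meshgrid(np.arange(howmany),
                      np.linspace(-1, 1, Nx),
                      np.linspace(-1, 1, Ny),
                      indexing='ij')
# Same functional form as in the text: a 2D pattern times the factor (n - 3).
u = np.sin(4*np.pi*(x + 2*y)) * np.exp(-x**2/2/0.2 - y**2/2/0.1) * (n - 3)

# Transform over x and y only; the first index is left untouched.
uk = np.fft.rfft2(u, axes=(1, 2))

# Reference spectrum: the slice n = 4, for which the factor (n - 3) equals 1.
base = np.fft.rfft2(u[4])

linear_in_n = all(np.allclose(uk[i], (i - 3) * base) for i in range(howmany))
print(linear_in_n)
```

If the first index had been transformed as well, the slices of uk would mix and this proportionality would be lost.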
As discussed above, if we perform a forward transform followed by a backward transform, the difference from the original array will be nontrivial, because of the padding and the fact that we start from an input that is supposed to be padded but in fact is not:
pf.forward(u,uk)
pf.backward(uk,u)
If we plot the difference from the original array (0th index), we see the following:
On the other hand, if we perform the forward and backward transforms once more and compare the result with the array obtained after the first round trip, the difference becomes comparable to machine precision.
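This behaviour can be illustrated in 1D with plain NumPy. The sketch below emulates a truncating forward transform and a zero-padding backward transform; the helper names forward_truncated and backward_padded are made up for this illustration, and the normalization and Nyquist handling are those of NumPy, which may differ from what mpi4py-fft does internally. The first round trip removes everything above the retained modes, so it changes the input. The second round trip changes nothing, because the array is already band-limited.

```python
import numpy as np

def forward_truncated(u, N):
    """Transform a padded real array and keep only the first N//2 + 1 modes."""
    return np.fft.rfft(u)[:N//2 + 1]

def backward_padded(uk, Np):
    """Zero-pad the spectrum to Np//2 + 1 modes and transform back to Np points."""
    ukp = np.zeros(Np//2 + 1, dtype=complex)
    ukp[:uk.size] = uk
    return np.fft.irfft(ukp, n=Np)

N, Np = 16, 24                      # unpadded and padded sizes (factor 3/2)
rng = np.random.default_rng(0)
u0 = rng.standard_normal(Np)        # arbitrary input, not band-limited

u1 = backward_padded(forward_truncated(u0, N), Np)   # first round trip
u2 = backward_padded(forward_truncated(u1, N), Np)   # second round trip

first_diff = np.max(np.abs(u1 - u0))    # nontrivial: high modes were removed
second_diff = np.max(np.abs(u2 - u1))   # at the level of rounding errors
print(first_diff, second_diff)
```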
Wavenumbers and FFT:
It is clear from the figures above that mpi4py_fft computes non-centered FFTs, as is usually the case with FFT routines. The corresponding wavenumbers can be written as follows:
# dkx and dky are the wavenumber spacings in x and y (they must be defined beforehand)
kx=np.r_[0:int(Nx/2),-int(Nx/2):0]*dkx
#k = np.fft.fftfreq(Nx, 1./Nx).astype(int)
ky=np.r_[0:int(Ny/2+1)]*dky
Equivalently, we can write the indices that turn a non-centered array into a centered one (or, for even Nx, a centered array back into a non-centered one) as:
inds=np.r_[int(Nx/2):Nx,0:int(Nx/2)]
So that, for example, with uk shaped as (n,Nx,Ny), we have:
np.allclose(np.fft.fftshift(uk,1),uk[:,inds,:])
Out: True
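The wavenumber ordering and the centering indices can be checked with a small NumPy-only example. The values dkx=dky=1 and the random array uk below are assumptions made for this illustration. They are not taken from the example above.

```python
import numpy as np

Nx, Ny = 8, 8
dkx, dky = 1.0, 1.0   # assumed wavenumber spacings (illustrative values)

kx = np.r_[0:Nx//2, -(Nx//2):0] * dkx      # 0, 1, 2, 3, -4, -3, -2, -1
ky = np.r_[0:Ny//2 + 1] * dky              # 0, 1, 2, 3, 4
inds = np.r_[Nx//2:Nx, 0:Nx//2]            # centering permutation

# The ordering agrees with NumPy's own helper functions.
same_kx = np.allclose(kx, np.fft.fftfreq(Nx, 1.0/Nx) * dkx)
same_ky = np.allclose(ky, np.fft.rfftfreq(Ny, 1.0/Ny) * dky)

# Applying inds sorts kx into increasing (centered) order.
kx_centered = kx[inds]

# Indexing with inds along axis 1 is the same as fftshift along that axis,
# and for even Nx applying it twice restores the original ordering.
rng = np.random.default_rng(0)
shape = (2, Nx, Ny//2 + 1)
uk = rng.standard_normal(shape) + 1j*rng.standard_normal(shape)
shift_matches = np.array_equal(np.fft.fftshift(uk, axes=1), uk[:, inds, :])
restores = np.array_equal(uk[:, inds, :][:, inds, :], uk)

print(same_kx, same_ky, shift_matches, restores)
print(kx_centered)
```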
If we want to remove the Nyquist modes for some reason:
uk[:,:,-1]=0
uk[:,int(Nx/2),:]=0
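To see which modes these two assignments remove, the following NumPy-only sketch applies them to the output of np.fft.rfft2, used here as a serial stand-in for the parallel transform (an assumption of this illustration). It confirms that index Nx/2 along x and the last index along y carry the Nyquist wavenumbers, and that the filtered spectrum still corresponds to a real field.

```python
import numpy as np

Nx, Ny = 8, 8
kx = np.fft.fftfreq(Nx, 1.0/Nx)      # 0, 1, 2, 3, -4, -3, -2, -1
ky = np.fft.rfftfreq(Ny, 1.0/Ny)     # 0, 1, 2, 3, 4

rng = np.random.default_rng(1)
u = rng.standard_normal((2, Nx, Ny))
uk = np.fft.rfft2(u, axes=(1, 2))    # shape (2, Nx, Ny//2 + 1)

# Remove the Nyquist modes in y (last index) and in x (index Nx/2).
uk[:, :, -1] = 0
uk[:, Nx//2, :] = 0

# Transform back and forth: the filtered spectrum is reproduced, so the
# zeroing did not break the Hermitian symmetry of a real field.
u_filtered = np.fft.irfft2(uk, s=(Nx, Ny), axes=(1, 2))
uk_check = np.fft.rfft2(u_filtered, axes=(1, 2))

print(kx[Nx//2], ky[-1])
print(np.allclose(uk_check, uk))
```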
In order to see how the padding works, we can try:
from mpi4py import MPI
from mpi4py_fft import PFFT, newDistArray, DistArray
import numpy as np
import matplotlib.pylab as plt
comm = MPI.COMM_WORLD
n, Nx, Ny = 2, 8, 8
padx, pady = 1.5, 1.5
# padded transform: (n,Nx,Ny/2+1) in Fourier space <-> (n,Nx*padx,Ny*pady) in real space
fft = PFFT(comm, shape=(n, Nx, Ny), axes=(1, 2), grid=[1, -1, 1], padding=[1, padx, pady], collapse=False)
# unpadded transform acting directly on the padded real-space resolution
fft_nopad = PFFT(comm, shape=(n, int(Nx*padx), int(Ny*pady)), axes=(1, 2), grid=[1, -1, 1])
uk = DistArray((n, Nx, int(Ny/2+1)), subcomm=(1, 0, 1), dtype=complex, alignment=2)
kx = DistArray((Nx, int(Ny/2+1)), subcomm=(0, 1), dtype=float, alignment=1)
ky = DistArray((Nx, int(Ny/2+1)), subcomm=(0, 1), dtype=float, alignment=1)
inds = np.r_[int(Nx/2):Nx, 0:int(Nx/2)]
# fill uk with recognizable values: real part from the x index (through inds), imaginary part from the y index
uk[0, :, :] = np.array([[inds[l]+1j*(m+uk.local_slice()[2].start) for m in range(uk.shape[2])] for l in range(uk.shape[1])])
uk[1, :, :] = np.array([[2*inds[l]+1j*(m+uk.local_slice()[2].start+1) for m in range(uk.shape[2])] for l in range(uk.shape[1])])
# remove the Nyquist modes in y and in x
uk[:, :, -1] = 0
uk[:, int(Nx/2), :] = 0
u = newDistArray(fft, forward_output=False)
ukpad = newDistArray(fft_nopad, forward_output=True)
# padded backward transform, then unpadded forward transform of the result
fft.backward(uk, u)
fft_nopad.forward(u, ukpad)
If we plot this, we get (with and without the uk[:,:,-1]=0 and uk[:,int(Nx/2),:]=0 lines):
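For comparison, the same experiment can be emulated serially with NumPy. The sketch below is an illustration under stated assumptions, not a description of mpi4py-fft internals. The helper pad_spectrum is hypothetical: it places the non-negative kx rows at the start of the padded array and the negative kx rows at its end. Normalization follows NumPy, so the padded spectrum is recovered exactly, whereas the library may apply a different scaling. How mpi4py-fft itself treats the Nyquist rows is not verified here, which is why they are zeroed first.

```python
import numpy as np

def pad_spectrum(uk, Npx, Npy):
    """Embed an (n, Nx, Ny//2+1) spectrum into a zero-filled (n, Npx, Npy//2+1) one."""
    n, Nx, Nyh = uk.shape
    ukp = np.zeros((n, Npx, Npy//2 + 1), dtype=complex)
    ukp[:, :Nx//2, :Nyh] = uk[:, :Nx//2, :]      # kx >= 0 rows stay at the start
    ukp[:, -(Nx//2):, :Nyh] = uk[:, Nx//2:, :]   # kx < 0 rows move to the end
    return ukp

n, Nx, Ny = 2, 8, 8
Npx, Npy = int(Nx*1.5), int(Ny*1.5)

# Spectrum of a real field, with the Nyquist modes removed as in the text.
rng = np.random.default_rng(2)
uk = np.fft.rfft2(rng.standard_normal((n, Nx, Ny)), axes=(1, 2))
uk[:, :, -1] = 0
uk[:, Nx//2, :] = 0

# "Padded backward" followed by "unpadded forward" on the finer grid.
ukp = pad_spectrum(uk, Npx, Npy)
u = np.fft.irfft2(ukp, s=(Npx, Npy), axes=(1, 2))
ukpad = np.fft.rfft2(u, axes=(1, 2))

print(u.shape, ukpad.shape)
print(np.allclose(ukpad, ukp))
```

In ukpad the original modes sit in the low-|kx|, low-ky corners, and the rows and columns introduced by the padding remain zero.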