The primary computation in a pseudo-spectral code is the convolution. Here we develop a Python class that uses mpi4py-fft to compute parallel convolutions with arbitrary multiplier functions, as described earlier.

Formulation:

A basic convolution class can be written as:

pfft_conv.py

from mpi4py import MPI
from mpi4py_fft import PFFT,newDistArray, DistArray
import numpy as np
from numba import njit

#... mult_in, mult_out62 and hermitian_symmetrize are defined here ...#

class hwconv:
    """Parallel padded (dealiased) 2D convolution built on mpi4py-fft.

    ``num_in`` spectral fields are produced by ``fmultin``, transformed to
    real space, combined in place by ``fmultout`` and the first ``num_out``
    of them are transformed back to spectral space.

    Parameters
    ----------
    Nx, Ny : int
        Number of grid points in x and y (before padding).
    padx, pady : float
        Padding factors passed to the FFT plans for the x and y axes.
    num_in : int
        Number of fields transformed to real space.
    num_out : int
        Number of fields transformed back to spectral space.
    fmultin : callable
        Called as ``fmultin(u, datk, kx, ky, ksqr)``; expected to fill
        ``datk`` from the input ``u``.
    fmultout : callable
        Called as ``fmultout(dat)``; expected to leave the products in the
        first ``num_out`` entries of ``dat``.
    comm : MPI communicator
        Communicator over which the arrays are distributed.
    """

    def __init__(self,Nx,Ny,padx=1.5,pady=1.5,num_in=6, num_out=2,
                 fmultin=mult_in, fmultout=mult_out62,comm=MPI.COMM_WORLD):
        self.comm=comm
        # convolve() needs Nx to locate the kx Nyquist row, so keep the grid size.
        self.Nx=Nx
        self.Ny=Ny
        self.num_in=num_in
        self.num_out=num_out
        self.nar=max(num_in,num_out)
        # Field index axis is never padded; x and y use the caller's factors.
        padding=[1,padx,pady]
        self.ffti=PFFT(comm,shape=(self.num_in,Nx,Ny),axes=(1,2), grid=[1,-1,1], padding=padding,collapse=False)
        self.ffto=PFFT(comm,shape=(self.num_out,Nx,Ny),axes=(1,2), grid=[1,-1,1], padding=padding,collapse=False)
        # Work buffers: datk holds spectral data, dat holds real-space data.
        self.datk=newDistArray(self.ffti,forward_output=True)
        self.dat=newDistArray(self.ffti,forward_output=False)
        # Global wavenumbers: kx in FFT ordering, ky non-negative (real-to-complex layout).
        lkx=np.r_[0:int(Nx/2),-int(Nx/2):0]
        lky=np.r_[0:int(Ny/2+1)]
        self.kx=DistArray((Nx,int(Ny/2+1)),subcomm=(1,0),dtype=float,alignment=0)
        self.ky=DistArray((Nx,int(Ny/2+1)),subcomm=(1,0),dtype=float,alignment=0)
        # Each rank keeps only the part of the wavenumber grid it owns.
        self.kx[:],self.ky[:]=np.meshgrid(lkx[self.kx.local_slice()[0]],lky[self.ky.local_slice()[1]],indexing='ij')
        self.ksqr=self.kx**2+self.ky**2
        self.fmultin=fmultin
        self.fmultout=fmultout

    def convolve(self,u):
        """Return the ``num_out`` convolution terms for the spectral input ``u``.

        ``u`` is modified in place: it is passed to ``hermitian_symmetrize``
        and its kx and ky Nyquist modes are set to zero. The result is a
        slice of the internal buffer ``self.datk``, so it is overwritten by
        the next call; copy it if it has to be kept.
        """
        nyq_x=self.Nx//2
        hermitian_symmetrize(u)
        # The last ky index (Nyquist) lives only on the rank owning the end of axis 2.
        if(u.local_slice()[2].stop==u.global_shape[2]):
            u[:,:,-1]=0
        u[:,nyq_x,:]=0
        self.fmultin(u,self.datk,self.kx,self.ky,self.ksqr)
        self.ffti.backward(self.datk,self.dat)
        self.fmultout(self.dat)
        self.ffto.forward(self.dat[:self.num_out,],self.datk[:self.num_out,])
        # Remove the Nyquist modes from the result as well.
        if(self.datk.local_slice()[2].stop==self.datk.global_shape[2]):
            self.datk[:,:,-1]=0
        self.datk[:,nyq_x,:]=0
        return self.datk[:self.num_out,]
  • The idea is to compute num_out number of convolutions using num_in number of initial input fields.
  • The forms of the convolution terms as sums or differences of real space fields are computed in the function fmultout provided by the user.
  • The num_in (by default 6) input fields for the convolution are computed by the function fmultin before the Fourier transforms, using the input vectors (of which there are usually fewer than 6) and combinations of kx, ky and ksqr.

Usage:

We can use this to compute the same convolution of earlier discussions as follows:

# Example: run the convolution on two synthetic spectral fields on an 8x8 grid.
Nx,Ny=8,8
h=hwconv(Nx,Ny)

# Spectral arrays hold two fields each; wavenumber arrays hold one.
spec_shape=(2,Nx,int(Ny/2+1))
wave_shape=spec_shape[1:]
uk=DistArray(spec_shape,subcomm=(1,1,0),dtype=complex,alignment=1)
vk=DistArray(spec_shape,subcomm=(1,1,0),dtype=complex,alignment=1)
kx=DistArray(wave_shape,subcomm=(1,0),dtype=float,alignment=0)
ky=DistArray(wave_shape,subcomm=(1,0),dtype=float,alignment=0)

# Global row/column indices owned by this rank.
inds=np.r_[int(Nx/2):Nx,0:int(Nx/2)]
rows=np.r_[uk.local_slice()[1]]
cols=np.r_[uk.local_slice()[2]]

# Broadcast a column of row values against the column indices to fill each field.
row_vals=inds[rows-1][:,np.newaxis]
uk[0,:,:]=row_vals+1j*cols
uk[1,:,:]=2*row_vals+1j*(cols+1)
vk[:]=h.convolve(uk)

The result is:

[[[    0.   +0.j -1500. -224.j -3936. -896.j -4212. +252.j     0.   +0.j]
  [ 1260.   -0.j  -104. -160.j -2529.-1017.j -2856.-1084.j     0.   +0.j]
  [ 2880.   -0.j  2206. -118.j  -272. -576.j -1326. -886.j     0.   +0.j]
  [ 3240.   -0.j  3360. -404.j  1308. -278.j   288.   +0.j     0.   +0.j]
  [    0.   +0.j     0.   +0.j     0.   +0.j     0.   +0.j     0.   +0.j]
  [ 3240.   +0.j   504. -972.j  -384.-1086.j  -648. -648.j     0.   +0.j]
  [ 2880.   +0.j   718. -282.j  -432. -224.j -1050. +356.j     0.   +0.j]
  [ 1260.   +0.j     8.  -16.j -1401.   +3.j -2040.+1076.j     0.   +0.j]]

 [[   -0.   +0.j   -12.  -76.j   -48.  -40.j  -108. +192.j     0.   +0.j]
  [    0. +288.j    15. +107.j   -12.  +80.j   -84. +252.j     0.   +0.j]
  [    0. +480.j    33. +212.j    12. +140.j   -68. +284.j     0.   +0.j]
  [    0. +576.j    42. +254.j    24. +160.j   -60. +300.j     0.   +0.j]
  [    0.   +0.j     0.   +0.j     0.   +0.j     0.   +0.j     0.   +0.j]
  [    0. -576.j   -54. -402.j   -72. -272.j   -48. -144.j     0.   +0.j]
  [    0. -480.j   -47. -330.j   -68. -216.j   -58.  -88.j     0.   +0.j]
  [    0. -288.j   -33. -217.j   -60. -136.j   -78.  +18.j     0.   +0.j]]]

which is basically the same as in my previous post but shifted and with a non-compact form.