Even though there are many different ODE solvers in Python, a simple parallel adaptive time step solver that can use MPI seems to be lacking. One either has to use a huge framework or fall back to fixed time step solvers.
CVODE is an ODE solver developed by LLNL as part of SUNDIALS, the SUite of Nonlinear and DIfferential/ALgebraic Equation Solvers. It is a C library that supports various vector formats, among which is the NVECTOR_PARALLEL module, which can be used for writing an MPI-based ODE solver.
Minimalist Python Wrapper for CVODE
MPCVode is a really minimalistic Python wrapper for CVODE that is intended to be used by modifying the C code as needed.
The C interface consists of a parameter struct and two routines, declared as:
#include <nvector/nvector_parallel.h>
#include <mpi.h>

typedef struct mpcv_pars{
  double *y, *dydt;   /* user-owned state and derivative buffers */
  double t0, t;       /* initial and current time */
  int N, Nloc;        /* global and local (per-rank) problem sizes */
  MPI_Comm comm;
  void *solver;       /* CVODE memory block */
  N_Vector uv;        /* parallel N_Vector wrapping y */
  void (*fnpy)(double, double *, double *);  /* RHS callback (from Python) */
}mpcv_pars;

void init_solver(int N, int Nloc, double *y, double *dydt, double t0,
                 void (*fnpy)(double, double *, double *),
                 double atol, double rtol, int mxsteps);
void integrate_to(double tnext, double *t, int *state);
These two functions are then defined as:
#include <stdlib.h>
#include <cvode/cvode.h>
#include <nvector/nvector_parallel.h>
#include <sunnonlinsol/sunnonlinsol_fixedpoint.h>
#include <sundials/sundials_types.h>
#include <mpi.h>
#include "mpcvode.h"

/* single global instance: only one solver can be active at a time */
mpcv_pars *p_glob;

/* RHS wrapper in the form CVODE expects; forwards to the Python callback */
static int fnmpcvode(realtype t, N_Vector y, N_Vector dydt, void *fdata){
  mpcv_pars *p=(mpcv_pars*)fdata;
  p->fnpy(t, NV_DATA_P(y), NV_DATA_P(dydt));
  return 0;
}

void init_solver(int N, int Nloc, double *y, double *dydt,
                 double t0, void (*fnpy)(double, double *, double *),
                 double atol, double rtol, int mxsteps){
  SUNNonlinearSolver NLS;
  int state;
  mpcv_pars *p;
  p=malloc(sizeof(mpcv_pars));
  p->N=N;
  p->Nloc=Nloc;
  p->comm=MPI_COMM_WORLD;
  p->y=y;
  p->t0=t0;
  p->dydt=dydt;
  p->fnpy=fnpy;
  /* wrap the user buffer in a parallel N_Vector without copying */
  p->uv=N_VMake_Parallel(p->comm, Nloc, N, y);
  /* Adams method with fixed-point (functional) nonlinear iteration */
  p->solver=CVodeCreate(CV_ADAMS);
  state = CVodeSetUserData(p->solver, p);
  state = CVodeSetMaxNumSteps(p->solver, mxsteps);
  state = CVodeInit(p->solver, fnmpcvode, t0, p->uv);
  state = CVodeSStolerances(p->solver, rtol, atol);
  NLS = SUNNonlinSol_FixedPoint(p->uv, 0);
  state = CVodeSetNonlinearSolver(p->solver, NLS);
  p_glob=p;
}

void integrate_to(double tnext, double *t, int *state){
  mpcv_pars *p=p_glob;
  /* advance to tnext in CV_NORMAL mode; the y buffer is updated in place */
  *state=CVode(p->solver, tnext, p->uv, &(p->t), CV_NORMAL);
  *t=p->t;
}
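To use this from Python, the C code first has to be built as a shared library. As a rough sketch, assuming the source file is named mpcvode.c to match the header, and that SUNDIALS and MPI are installed where the compiler can find them (library names vary between SUNDIALS versions), something along these lines should work:
mpicc -shared -fPIC mpcvode.c -o libmpcvode.so -lsundials_cvode -lsundials_nvecparallel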
Following the recipe for calling C functions from Python that can call back into Python functions, we can import this shared library into Python with ctypes. It is probably better to write an actual Python class as a wrapper:
import os
import numpy as np
from ctypes import cdll, CFUNCTYPE, POINTER, c_double, c_int, byref

class mpcvode:
    def __init__(self, fn, y, dydt, t0, t1, **kwargs):
        self.libmpcvod = cdll.LoadLibrary(os.path.dirname(__file__)+'/libmpcvode.so')
        self.fnpytype = CFUNCTYPE(None, c_double, POINTER(c_double), POINTER(c_double))
        self.local_shape = y.shape
        self.global_shape = y.global_shape
        # sizes are measured in doubles, so a complex array counts twice
        self.global_size = int(np.prod(y.global_shape) *
                               y.dtype.itemsize/np.dtype(float).itemsize)
        self.local_size = int(y.size*y.dtype.itemsize/np.dtype(float).itemsize)
        self.fn = fn
        self.kwargs = kwargs
        self.comm = y.comm
        self.t0 = t0
        self.t1 = t1
        self.y = y
        self.dydt = dydt
        self.t = t0
        self.state = 0
        self.atol = kwargs.get('atol', 1e-8)
        self.rtol = kwargs.get('rtol', 1e-6)
        self.mxsteps = int(kwargs.get('mxsteps', 10000))
        # keep a reference to the ctypes callback so it is not garbage collected
        self.fnmpcvod = self.fnpytype(lambda t, y, dydt: self.fnforw(t, y, dydt))
        self.libmpcvod.init_solver(self.global_size, self.local_size,
                                   self.y.ctypes.data_as(POINTER(c_double)),
                                   self.dydt.ctypes.data_as(POINTER(c_double)),
                                   c_double(self.t0), self.fnmpcvod,
                                   c_double(self.atol), c_double(self.rtol),
                                   c_int(self.mxsteps))

    def fnforw(self, t, y, dydt):
        # re-wrap the raw double pointers as complex arrays of the local
        # shape (this wrapper assumes the state is complex-valued)
        u = np.ctypeslib.as_array(y, (self.local_size,)).view(
            dtype=complex).reshape(self.local_shape)
        dudt = np.ctypeslib.as_array(dydt, (self.local_size,)).view(
            dtype=complex).reshape(self.local_shape)
        self.fn(t, u, dudt)

    def integrate_to(self, tnext):
        t = c_double()
        state = c_int()
        self.libmpcvod.integrate_to(c_double(tnext), byref(t), byref(state))
        self.t = t.value
        self.state = state.value

    def successful(self):
        # CVode returns CV_SUCCESS (0) when it reaches tnext
        return self.state == 0
A Simple Example
While an MPI example is usually a complicated beast, we can give a very simple one here.
from mpi4py import MPI
import numpy as np
from mpcvode import mpcvode
import matplotlib.pylab as plt

def splitmpi(shape, rank, size, axis=-1, Nsp=0):
    # block-distribute Nsp points along `axis` over `size` ranks;
    # returns the local shape and the offset of the local block
    sp = list(shape)
    if Nsp == 0:
        Nsp = sp[axis]
    nperpe = Nsp//size
    nrem = Nsp - size*nperpe
    n = nperpe + (rank < nrem)          # first nrem ranks get one extra point
    start = rank*nperpe + min(rank, nrem)
    off = np.zeros(len(sp), dtype=int)
    sp[axis] = n
    off[axis] = start
    return sp, off
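# As a worked example of the splitting above: distributing 10 points over
# 3 ranks yields local sizes 4, 3, 3 at offsets 0, 4, 7:
#   splitmpi((4, 10), 0, 3)  ->  ([4, 4], array([0, 0]))
#   splitmpi((4, 10), 1, 3)  ->  ([4, 3], array([0, 4]))
#   splitmpi((4, 10), 2, 3)  ->  ([4, 3], array([0, 7]))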
class distarray(np.ndarray):
    def __new__(cls, shape, dtype=float, buffer=None,
                offset=0, strides=None, order=None,
                axis=-1, Nsp=0, comm=MPI.COMM_WORLD):
        dims = len(shape)
        locshape, loc0 = splitmpi(shape, comm.rank, comm.size, axis, Nsp)
        if buffer is None:
            buffer = np.zeros(locshape, dtype)
        else:
            if dtype != buffer.dtype:
                print("dtype!=buffer.dtype, ignoring dtype argument")
                dtype = buffer.dtype
        obj = super(distarray, cls).__new__(cls, locshape, dtype, buffer,
                                            offset, strides, order)
        obj.loc0 = loc0
        obj.global_shape = shape
        obj.local_slice = tuple(slice(loc0[l], loc0[l]+locshape[l], None)
                                for l in range(dims))
        obj.comm = comm
        return obj
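# On 4 ranks, a (4, 4) distarray gives each rank a local (4, 1) column;
# local_slice maps the local block back into the global index space.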
phi = distarray((4, 4), dtype=complex, comm=MPI.COMM_WORLD)
dphidt = distarray((4, 4), dtype=complex, comm=MPI.COMM_WORLD)
gam = 0.1
# fill phi with i+j using global indices, so the data is rank-independent
phi[:, :] = [[i+j for j in np.r_[phi.local_slice[1]]]
             for i in np.r_[phi.local_slice[0]]]
phi0 = phi.copy()

def fntest(t, y, dydt):
    # simple linear test problem: dy/dt = gam*y, solution y0*exp(gam*t)
    dydt[:, :] = gam*y[:, :]

mpcv = mpcvode(fntest, phi, dphidt, 0.0, 100.0, atol=1e-12, rtol=1e-8)
t = np.arange(10.0)
z = np.zeros(10, dtype=complex)
z[0] = phi[0, 0]
for l in range(1, t.shape[0]):
    mpcv.integrate_to(t[l])
    z[l] = mpcv.y[0, 0]
plt.plot(t, z.real, 'x', t, z[0].real*np.exp(0.1*t), '--')
plt.legend(['numerical solution', str(z[0].real)+'*exp(0.1*t)'])
plt.show()
We can run this using something like:
mpirun -np 4 python ex1.py
The resulting figures show the time evolution of the zeroth local element on each process; since the test problem is dy/dt = gam*y, the numerical points should fall on the analytic curve z[0]*exp(0.1*t).