Even though there are many different ODE solvers in Python, a simple parallel adaptive time-step solver that can use MPI seems to be lacking. One either has to use a huge framework, or fall back to fixed time-step solvers.

CVODE is an ODE solver developed by LLNL as part of SUNDIALS, the SUite of Nonlinear and DIfferential/ALgebraic Equation Solvers. It is a C library that supports various vector formats, among which is the NVECTOR_PARALLEL module that can be used for writing an MPI-based ODE solver.

Minimalist Python Wrapper for CVODE

MPCVode is a minimalistic Python wrapper for CVODE that is intended to be used by modifying the C code as needed.

The primary C initialization routines look like:

mpcvode.h

#include <nvector/nvector_parallel.h>
#include <mpi.h> 

/* Solver state shared between init_solver() and integrate_to(). */
typedef struct mpcv_pars{
  double *y, *dydt;   /* caller-owned state and derivative buffers */
  double t0,t;        /* initial time and current solver time */
  int N,Nloc;         /* global and per-rank (local) vector lengths */
  MPI_Comm comm;      /* communicator for the parallel N_Vector */
  void *solver;       /* CVODE memory block returned by CVodeCreate() */
  N_Vector uv;        /* parallel N_Vector wrapping y (no copy) */
  void (*fnpy)(double, double *, double *);  /* RHS callback: fnpy(t, y, dydt) */
}mpcv_pars;

/* Allocate and configure a CVODE (Adams) solver over the caller's y/dydt
   buffers.  Must be called once per process before integrate_to(). */
void init_solver(int N,int Nloc, double *y, double *dydt, double t0,
		 void (*fnpy)(double,double *,double *),
		 double atol, double rtol, int mxsteps);
/* Advance the solution to tnext; writes the time actually reached to *t
   and the CVode return flag (0 on success) to *state. */
void integrate_to(double tnext, double *t, int *state);

These two functions are then defined as:

mpcvode.c

#include <stdlib.h>
#include <cvode/cvode.h>
#include <nvector/nvector_parallel.h>
#include <sunnonlinsol/sunnonlinsol_fixedpoint.h>
#include <sundials/sundials_types.h>
#include <mpi.h>
#include "mpcvode.h"

/* Global handle: one solver instance per process, set by init_solver(). */
mpcv_pars *p_glob;

/* CVODE right-hand-side adapter: unwrap the parallel vectors to their raw
   local double arrays and forward them to the registered Python callback.
   Always reports success (0) to CVODE. */
static int fnmpcvode(realtype t, N_Vector y, N_Vector dydt, void *fdata){
  ((mpcv_pars *)fdata)->fnpy(t, NV_DATA_P(y), NV_DATA_P(dydt));
  return 0;
}

/* Build the per-process solver state and initialize CVODE.
 *
 * N, Nloc : global and local vector lengths (in doubles).
 * y, dydt : caller-owned buffers; uv wraps y without copying, so CVODE
 *           writes the state directly into the caller's memory.
 * fnpy    : right-hand-side callback fnpy(t, y, dydt).
 * atol, rtol, mxsteps : scalar tolerances and max internal steps.
 *
 * NOTE(review): the CVode* return codes collected in 'state' are never
 * checked; failures here go unreported.  The fixed-point nonlinear solver
 * (no Jacobian) is paired with the CV_ADAMS method, i.e. a functional
 * (non-stiff) iteration. */
void init_solver(int N,int Nloc, double *y, double *dydt,
		 double t0, void (*fnpy)(double,double *,double *),
		 double atol, double rtol, int mxsteps){
  SUNNonlinearSolver NLS;
  int state;
  mpcv_pars *p;
  p=malloc(sizeof(mpcv_pars));
  p->N=N;
  p->Nloc=Nloc;
  p->comm=MPI_COMM_WORLD;
  p->y=y;
  p->t0=t0;
  p->dydt=dydt;
  p->fnpy=fnpy;
  /* Wrap the existing y buffer as a parallel N_Vector (zero-copy). */
  p->uv=N_VMake_Parallel(p->comm,Nloc,N,y);
  p->solver=CVodeCreate(CV_ADAMS);
  /* p is passed back to fnmpcvode as the user-data pointer. */
  state = CVodeSetUserData(p->solver, p);
  state = CVodeSetMaxNumSteps(p->solver, mxsteps);
  state = CVodeInit(p->solver, fnmpcvode,t0,p->uv);
  state = CVodeSStolerances(p->solver, rtol, atol);
  /* Fixed-point solver with 0 Anderson-acceleration vectors. */
  NLS = SUNNonlinSol_FixedPoint(p->uv, 0);
  state = CVodeSetNonlinearSolver(p->solver, NLS);
  p_glob=p;
};

/* Advance the solution to tnext in CV_NORMAL mode (interpolate to tnext).
   The time actually reached is written to *t and the CVode return flag
   (0 == CV_SUCCESS) to *state. */
void integrate_to(double tnext, double *t, int *state){
  *state = CVode(p_glob->solver, tnext, p_glob->uv, &(p_glob->t), CV_NORMAL);
  *t = p_glob->t;
}

Following the instructions for calling C functions from Python that can call back into Python functions, we can create a shared library and import it into Python. It is probably better to write an actual Python class as a wrapper:

import os
import numpy as np
from ctypes import cdll,CFUNCTYPE,POINTER,c_double,c_int,byref

class mpcvode:
    """Minimalistic ctypes wrapper around the libmpcvode.so CVODE shim.

    Exposes a small scipy.integrate.ode-like interface: construct once,
    then call integrate_to(tnext) repeatedly and check successful().
    """

    def __init__(self, fn, y, dydt, t0, t1, **kwargs):
        """Set up the C solver for dy/dt = fn(t, y, dydt).

        Parameters
        ----------
        fn : callable(t, u, dudt)
            Right-hand side; must write the derivative into dudt in place.
        y, dydt : array-like with .global_shape and .comm (e.g. distarray)
            Distributed state and derivative buffers.  fnforw views them
            as complex (NOTE(review): dtype=complex is hard-coded there —
            confirm callers always pass complex arrays).
        t0, t1 : float
            Initial time and final time (t1 is stored but not used here).
        **kwargs
            atol (default 1e-8), rtol (default 1e-6), mxsteps (default 10000).
        """
        self.libmpcvod = cdll.LoadLibrary(
            os.path.join(os.path.dirname(__file__), 'libmpcvode.so'))
        # C callback signature: void fn(double t, double *y, double *dydt)
        self.fnpytype = CFUNCTYPE(None, c_double,
                                  POINTER(c_double), POINTER(c_double))
        self.local_shape = y.shape
        self.global_shape = y.global_shape
        # Sizes are expressed in units of C doubles: a complex128 element
        # counts as two doubles.
        self.global_size = int(np.prod(y.global_shape) *
                               y.dtype.itemsize // np.dtype(float).itemsize)
        self.local_size = int(y.size * y.dtype.itemsize //
                              np.dtype(float).itemsize)
        self.fn = fn
        self.kwargs = kwargs
        self.comm = y.comm
        self.t0 = t0
        self.t1 = t1
        self.y = y
        self.dydt = dydt
        self.t = t0
        self.state = 0
        self.atol = kwargs.get('atol', 1e-8)
        self.rtol = kwargs.get('rtol', 1e-6)
        self.mxsteps = int(kwargs.get('mxsteps', 10000))
        # Keep the ctypes callback alive on self so it is not garbage
        # collected while the C solver still holds its address.
        # (The bound method is wrapped directly; the original went through
        # a redundant lambda and then had a stray no-op 'self.fn' line.)
        self.fnmpcvod = self.fnpytype(self.fnforw)
        self.libmpcvod.init_solver(
            self.global_size, self.local_size,
            self.y.ctypes.data_as(POINTER(c_double)),
            self.dydt.ctypes.data_as(POINTER(c_double)),
            c_double(self.t0), self.fnmpcvod,
            c_double(self.atol), c_double(self.rtol),
            c_int(self.mxsteps))

    def fnforw(self, t, y, dydt):
        """ctypes callback: re-wrap the raw double pointers as complex
        arrays of the local shape and forward them to the user RHS."""
        u = np.ctypeslib.as_array(y, (self.local_size,)).view(
            dtype=complex).reshape(self.local_shape)
        dudt = np.ctypeslib.as_array(dydt, (self.local_size,)).view(
            dtype=complex).reshape(self.local_shape)
        self.fn(t, u, dudt)

    def integrate_to(self, tnext):
        """Advance the solution to tnext; updates self.t and self.state.
        The state vector self.y is updated in place by the C solver."""
        t = c_double()
        state = c_int()
        self.libmpcvod.integrate_to(c_double(tnext), byref(t), byref(state))
        self.t = t.value
        self.state = state.value

    def successful(self):
        """Return True while the last CVode call returned success (== 0)."""
        return self.state == 0

A Simple Example:

While an mpi example is usually a complicated beast, we can give a very simple one here.

from mpi4py import MPI
import numpy as np
from mpcvode import mpcvode
import matplotlib.pylab as plt

def splitmpi(shape, rank, size, axis=-1, Nsp=0):
    """Split one axis of a global array shape across MPI ranks.

    Parameters
    ----------
    shape : sequence of int
        Global array shape.
    rank, size : int
        This process's rank and the communicator size.
    axis : int
        Axis to split (default: last).
    Nsp : int
        Number of points to split; 0 means use shape[axis].

    Returns
    -------
    (local_shape, offsets) : (list of int, ndarray of int)
        local_shape is `shape` with the split axis replaced by this rank's
        count; offsets is zero except for the global start index on `axis`.
    """
    sp = list(shape)
    if Nsp == 0:
        Nsp = sp[axis]
    # Exact integer split (the original used int(Nsp/size), whose float
    # division can misround for large Nsp): the first `nrem` ranks get
    # one extra point each.
    nperpe, nrem = divmod(Nsp, size)
    n = nperpe + (rank < nrem)
    start = rank * nperpe + min(rank, nrem)
    off = np.zeros(len(sp), dtype=int)
    sp[axis] = n
    off[axis] = start
    return sp, off

class distarray(np.ndarray):
    """ndarray subclass holding one MPI-local slab of a global array.

    Extra attributes: global_shape (full shape), loc0 (global start index
    per axis), local_slice (global-index slices owned by this rank), comm.
    """

    def __new__(self, shape, dtype=float, buffer=None,
                offset=0, strides=None, order=None,
                axis=-1, Nsp=0, comm=MPI.COMM_WORLD):
        dims = len(shape)
        locshape, loc0 = splitmpi(shape, comm.rank, comm.size, axis, Nsp)
        # Must be 'is None', not '== None': comparing an ndarray buffer
        # with '==' broadcasts element-wise and raises on truth testing.
        if buffer is None:
            buffer = np.zeros(locshape, dtype)
        else:
            if dtype != buffer.dtype:
                print("dtype!=buffer.dtype, ignoring dtype argument")
        dtype = buffer.dtype
        obj = super(distarray, self).__new__(self, locshape, dtype, buffer,
                                             offset, strides, order)
        obj.loc0 = loc0
        obj.global_shape = shape
        obj.local_slice = tuple(slice(loc0[l], loc0[l] + locshape[l], None)
                                for l in range(dims))
        obj.comm = comm
        return obj

    def __array_finalize__(self, obj):
        # Standard ndarray-subclass hook: propagate the metadata to views
        # and copies, which otherwise come back without these attributes.
        # obj is None during explicit construction in __new__ above.
        if obj is None:
            return
        self.loc0 = getattr(obj, 'loc0', None)
        self.global_shape = getattr(obj, 'global_shape', None)
        self.local_slice = getattr(obj, 'local_slice', None)
        self.comm = getattr(obj, 'comm', None)
    
# Distributed 4x4 complex state and derivative arrays, split over the ranks
# of MPI.COMM_WORLD along the last axis.
phi=distarray((4,4),dtype=complex,comm=MPI.COMM_WORLD)
dphidt=distarray((4,4),dtype=complex,comm=MPI.COMM_WORLD)
gam=0.1  # growth rate for dphi/dt = gam*phi
# Fill each rank's local slab with i+j using its *global* indices, so the
# initial condition depends on global position rather than local position.
phi[:,:]=[ [i+j for j in np.r_[phi.local_slice[1]] ] 
          for i in np.r_[phi.local_slice[0]] ]
phi0=phi.copy()
def fntest(t, y, dydt):
    """Right-hand side of dphi/dt = gam*phi, written into dydt in place."""
    dydt[...] = gam * y

# Integrate dphi/dt = 0.1*phi from t=0, sampling the local (0,0) element
# at t = 1..9 (z[0] holds the initial value).
mpcv=mpcvode(fntest,phi,dphidt,0.0,100.0,atol=1e-12,rtol=1e-8)
t=np.arange(10.0)
z=np.zeros(10,dtype=complex)
z[0]=phi[0,0]
for l in range(1,t.shape[0]):
    mpcv.integrate_to(t[l])
    z[l]=mpcv.y[0,0]  # mpcv.y is phi, updated in place by the solver
# Each rank plots its own zeroth element against the exact solution
# z0*exp(0.1*t).
plt.plot(t,z.real,'x',t,z[0].real*np.exp(0.1*t),'--')
plt.legend(['numerical solution',str(z[0].real)+'*exp(0.1*t)'])
plt.show()

Which we can run using something like:

mpirun -np 4 python ex1.py

The resulting figures show the time evolution of the zeroth element of each process.

mpcvode_ex