/* Published in: C
 * Some macros are adapted from the book "Python Scripting for Computational
 * Science": http://folk.uio.no/hpl/scripting/
 */
/* dtlsmodule.c */ #include <math.h> #include <stdio.h> #include <Python.h> #include "structmember.h" #include <numpy/arrayobject.h> /* ================================================================= MACROS */ #define QUOTE(s) # s /* turn s into string "s" */ #define NDIM_CHECK(a, expected_ndim, rt_error) \ if (PyArray_NDIM(a) != expected_ndim) { \ PyErr_Format(PyExc_ValueError, \ "%s array is %d-dimensional, but expected to be %d-dimensional", \ QUOTE(a), PyArray_NDIM(a), expected_ndim); \ return rt_error; \ } #define DIM_CHECK(a, dim, expected_length, rt_error) \ if (dim > PyArray_NDIM(a)) { \ PyErr_Format(PyExc_ValueError, \ "%s array has no %d dimension (max dim. is %d)", \ QUOTE(a), dim, PyArray_NDIM(a)); \ return rt_error; \ } \ if (PyArray_DIM(a, dim) != expected_length) { \ PyErr_Format(PyExc_ValueError, \ "%s array has wrong %d-dimension=%d (expected %d)", \ QUOTE(a), dim, PyArray_DIM(a, dim), expected_length); \ return rt_error; \ } #define TYPE_CHECK(a, tp, rt_error) \ if (PyArray_TYPE(a) != tp) { \ PyErr_Format(PyExc_TypeError, \ "%s array is not of correct type (%d)", QUOTE(a), tp); \ return rt_error; \ } #define CALLABLE_CHECK(func, rt_error) \ if (!PyCallable_Check(func)) { \ PyErr_Format(PyExc_TypeError, \ "%s is not a callable function", QUOTE(func)); \ return rt_error; \ } #define DIND1(a, i) *((double *) PyArray_GETPTR1(a, i)) #define DIND2(a, i, j) *((double *) PyArray_GETPTR2(a, i, j)) #define DIND3(a, i, j, k) *((double *) Py_Array_GETPTR3(a, i, j, k)) #define IIND1(a, i) *((int *) PyArray_GETPTR1(a, i)) #define IIND2(a, i, j) *((int *) PyArray_GETPTR2(a, i, j)) #define IIND3(a, i, j, k) *((int *) Py_Array_GETPTR3(a, i, j, k)) #define DEF_PYARRAY_GETTER(funcname, selftype, valname) \ static PyObject * \ funcname(selftype *self, void *closure) \ { \ Py_INCREF(self->valname); \ return PyArray_Return(self->valname); \ } #define DEF_PYARRAY_SETTER(funcname, selftype, valname, arraydim) \ static int \ funcname(selftype *self, PyObject *value, 
void *closure) \ { \ if (value == NULL) { \ PyErr_SetString( PyExc_TypeError, \ "Cannot delete the last attribute"); \ return -1; \ } \ if ( PyArray_Check(value) != 1 ){ \ PyErr_Format( PyExc_ValueError, \ "value is not of type numpy array"); \ return -1; \ } \ if ( PyArray_NDIM(value) != arraydim ){ \ PyErr_Format( PyExc_ValueError, \ "value array's dimension %d != arraydim", \ PyArray_NDIM(value)); \ return -1; \ } \ if ( PyArray_TYPE(value) != NPY_DOUBLE ){ \ PyErr_Format( PyExc_ValueError, \ "value array is not of type 'Python float'"); \ return -1; \ } \ Py_DECREF(self->valname); \ Py_INCREF(value); \ self->valname = (PyArrayObject *) value; \ return 0; \ } /* ========================================================== DTLSys struct */ typedef struct { PyObject_HEAD PyArrayObject *wt; PyArrayObject *bs; PyArrayObject *xt; } DTLSys; /* ============================================================ Declaration */ static void DTLSys_dealloc(DTLSys* self); static PyObject * DTLSys_new(PyTypeObject *type, PyObject *args, PyObject *kwds); static int DTLSys_init(DTLSys *self, PyObject *args, PyObject *kwds); static PyMemberDef DTLSys_members[] = { {NULL} /* Sentinel */ }; static PyObject * DTLSys_get_wt(DTLSys *self, void *closure); static int DTLSys_set_wt(DTLSys *self, PyObject *value, void *closure); static PyObject * DTLSys_get_bs(DTLSys *self, void *closure); static int DTLSys_set_bs(DTLSys *self, PyObject *value, void *closure); static PyObject * DTLSys_get_xt(DTLSys *self, void *closure); static int DTLSys_set_xt(DTLSys *self, PyObject *value, void *closure); static PyGetSetDef DTLSys_getseters[] = { {"wt", (getter)DTLSys_get_wt, (setter)DTLSys_set_wt, "Matrix", NULL}, {"bs", (getter)DTLSys_get_bs, (setter)DTLSys_set_bs, "Vector", NULL}, {"xt", (getter)DTLSys_get_xt, (setter)DTLSys_set_xt, "Vector time sequence", NULL}, {NULL} /* Sentinel */ }; static int _DTLSys_check_sys_conf(DTLSys *self); static PyObject * DTLSys_check_sys_conf(DTLSys *self); static PyObject 
* DTLSys_make_tms(DTLSys *self, PyObject *args); static PyMethodDef DTLSys_methods[] = { {"check_sys_conf", (PyCFunction)DTLSys_check_sys_conf, METH_NOARGS, "Check if system config is correct"}, {"make_tms", (PyCFunction)DTLSys_make_tms, METH_VARARGS, "Make TiMe Series"}, {NULL} /* Sentinel */ }; static PyTypeObject DTLSysType = { PyObject_HEAD_INIT(NULL) 0, /*ob_size*/ "dtls.DTLSys", /*tp_name*/ sizeof(DTLSys), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)DTLSys_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash */ 0, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ "DTLSys objects", /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ DTLSys_methods, /* tp_methods */ DTLSys_members, /* tp_members */ DTLSys_getseters, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)DTLSys_init, /* tp_init */ 0, /* tp_alloc */ DTLSys_new, /* tp_new */ }; static PyMethodDef module_methods[] = { {NULL} /* Sentinel */ }; #ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ #define PyMODINIT_FUNC void #endif PyMODINIT_FUNC initdtls(void) { PyObject* m; if (PyType_Ready(&DTLSysType) < 0){ return; } m = Py_InitModule3("dtls", module_methods, "Example module that creates an extension type."); if (m == NULL){ return; } Py_INCREF(&DTLSysType); PyModule_AddObject(m, "DTLSys", (PyObject *)&DTLSysType); import_array(); /* required NumPy initialization */ } /* ================================================================= Define */ static void DTLSys_dealloc(DTLSys* self) { Py_XDECREF(self->wt); Py_XDECREF(self->bs); Py_XDECREF(self->xt); self->ob_type->tp_free((PyObject*)self); } static 
PyObject * DTLSys_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { DTLSys *self; npy_intp wt_dims[2] = {0,0}; npy_intp bs_dims[1] = {0}; npy_intp xt_dims[2] = {0,0}; self = (DTLSys *)type->tp_alloc(type, 0); if (self != NULL) { self->wt = (PyArrayObject *) PyArray_SimpleNew(2, wt_dims, NPY_DOUBLE); if (self->wt == NULL){ Py_DECREF(self); return NULL; } self->bs = (PyArrayObject *) PyArray_SimpleNew(1, bs_dims, NPY_DOUBLE); if (self->bs == NULL){ Py_DECREF(self); return NULL; } self->xt = (PyArrayObject *) PyArray_SimpleNew(2, xt_dims, NPY_DOUBLE); if (self->xt == NULL){ Py_DECREF(self); return NULL; } } return (PyObject *)self; } static int DTLSys_init(DTLSys *self, PyObject *args, PyObject *kwds) { PyArrayObject *wt, *bs, *tmp; int t_max; npy_intp xt_dims[2] = {0,0}; if ( !PyArg_ParseTuple( args, "O!O!i:DTLSys.init", &PyArray_Type, &wt, &PyArray_Type, &bs, &t_max) ) { return -1; /* PyArg_ParseTuple has raised an exception */ } if ( wt==NULL || bs==NULL ) { } if ( t_max < 0 ) { } xt_dims[0] = PyArray_DIM(wt,0); xt_dims[1] = t_max; self->xt = (PyArrayObject *) PyArray_SimpleNew(2, xt_dims, NPY_DOUBLE); if (self->xt == NULL){ return -1; } tmp = self->wt; Py_INCREF(wt); self->wt = wt; Py_DECREF(tmp); tmp = self->bs; Py_INCREF(bs); self->bs = bs; Py_DECREF(tmp); if( _DTLSys_check_sys_conf(self) != 0 ){ PyErr_Clear(); } return 0; } DEF_PYARRAY_GETTER( DTLSys_get_wt, DTLSys, wt ) DEF_PYARRAY_SETTER( DTLSys_set_wt, DTLSys, wt, 2 ) DEF_PYARRAY_GETTER( DTLSys_get_bs, DTLSys, bs ) DEF_PYARRAY_SETTER( DTLSys_set_bs, DTLSys, bs, 1 ) DEF_PYARRAY_GETTER( DTLSys_get_xt, DTLSys, xt ) DEF_PYARRAY_SETTER( DTLSys_set_xt, DTLSys, xt, 2 ) static int _DTLSys_check_sys_conf(DTLSys *self) { int vecsize; NDIM_CHECK(self->wt, 2, -1); TYPE_CHECK(self->wt, NPY_DOUBLE, -1); NDIM_CHECK(self->bs, 1, -1); TYPE_CHECK(self->bs, NPY_DOUBLE, -1); vecsize = PyArray_DIM(self->wt,0); if (vecsize != PyArray_DIM(self->wt,1) ) { PyErr_Format( PyExc_ValueError, "self.wt must be square"); return -1; 
} if (vecsize != PyArray_DIM(self->bs,0) ) { PyErr_Format( PyExc_ValueError, "self.bs and self.wt[0] must be same shape"); return -1; } if (vecsize != PyArray_DIM(self->xt,0) ) { PyErr_Format( PyExc_ValueError, "self.xt[,0] and self.wt[0] must be same shape"); return -1; } return 0; } static PyObject * DTLSys_check_sys_conf(DTLSys *self) { if( _DTLSys_check_sys_conf(self) != 0 ){ PyErr_Clear(); Py_RETURN_FALSE; } Py_RETURN_TRUE; } static PyObject * DTLSys_make_tms(DTLSys *self, PyObject *args) { int vecsize, t_max, i, j, t; if( _DTLSys_check_sys_conf(self) != 0 ){ return NULL; } vecsize = PyArray_DIM(self->wt,0); t_max = PyArray_DIM(self->xt,1); for (t = 1; t < t_max; t++) { for (i = 0; i < vecsize; i++) { DIND2(self->xt,i,t) = DIND1(self->bs,i); for (j = 0; j < vecsize; j++) { DIND2(self->xt,i,t) += DIND2(self->wt,i,j) * DIND2(self->xt,j,t-1); } } } return Py_BuildValue(""); /* return None */ } /* # setup.py # build command : python setup.py build build_ext --inplace from numpy.distutils.core import setup, Extension import os, numpy name = 'dtls' sources = ['dtlsmodule.c'] include_dirs = [ numpy.get_include() ] setup( name = name, include_dirs = include_dirs, ext_modules = [Extension(name, sources)] ) */ /* # test code import scipy, pylab import dtls t_max = 200 rot = 5.0 * 2.0 * scipy.pi / t_max wt = scipy.array([ [ scipy.cos(rot), scipy.sin(rot) ], [ -scipy.sin(rot), scipy.cos(rot) ] ]) wt *= 0.99 bs = scipy.array([0.,0.]) a=dtls.DTLSys( wt, bs, t_max ) a.xt[0,0] = 0 a.xt[1,0] = 1 a.make_tms() pylab.clf() pylab.plot(a.xt[0],a.xt[1], 'o-') # Check calculation print (scipy.dot( a.wt, a.xt[:,0] ) + a.bs) - a.xt[:,1] */