Example of Numpy/C API


Published in: C
Save to your folder(s)

Some of the macros are taken from the book "Python Scripting for Computational Science" (http://folk.uio.no/hpl/scripting/).


Copy this code and paste it in your HTML
  1. /* dtlsmodule.c */
  2. #include <math.h>
  3. #include <stdio.h>
  4.  
  5. #include <Python.h>
  6. #include "structmember.h"
  7. #include <numpy/arrayobject.h>
  8.  
  9. /* ================================================================= MACROS */
  10. #define QUOTE(s) # s /* turn s into string "s" */
  11. #define NDIM_CHECK(a, expected_ndim, rt_error) \
  12.   if (PyArray_NDIM(a) != expected_ndim) { \
  13.   PyErr_Format(PyExc_ValueError, \
  14. "%s array is %d-dimensional, but expected to be %d-dimensional", \
  15. QUOTE(a), PyArray_NDIM(a), expected_ndim); \
  16.   return rt_error; \
  17.   }
  18. #define DIM_CHECK(a, dim, expected_length, rt_error) \
  19.   if (dim > PyArray_NDIM(a)) { \
  20.   PyErr_Format(PyExc_ValueError, \
  21. "%s array has no %d dimension (max dim. is %d)", \
  22. QUOTE(a), dim, PyArray_NDIM(a)); \
  23.   return rt_error; \
  24.   } \
  25.   if (PyArray_DIM(a, dim) != expected_length) { \
  26.   PyErr_Format(PyExc_ValueError, \
  27. "%s array has wrong %d-dimension=%d (expected %d)", \
  28. QUOTE(a), dim, PyArray_DIM(a, dim), expected_length); \
  29.   return rt_error; \
  30.   }
  31. #define TYPE_CHECK(a, tp, rt_error) \
  32.   if (PyArray_TYPE(a) != tp) { \
  33.   PyErr_Format(PyExc_TypeError, \
  34. "%s array is not of correct type (%d)", QUOTE(a), tp); \
  35.   return rt_error; \
  36.   }
  37. #define CALLABLE_CHECK(func, rt_error) \
  38.   if (!PyCallable_Check(func)) { \
  39.   PyErr_Format(PyExc_TypeError, \
  40. "%s is not a callable function", QUOTE(func)); \
  41.   return rt_error; \
  42.   }
  43.  
  44. #define DIND1(a, i) *((double *) PyArray_GETPTR1(a, i))
  45. #define DIND2(a, i, j) *((double *) PyArray_GETPTR2(a, i, j))
  46. #define DIND3(a, i, j, k) *((double *) Py_Array_GETPTR3(a, i, j, k))
  47.  
  48. #define IIND1(a, i) *((int *) PyArray_GETPTR1(a, i))
  49. #define IIND2(a, i, j) *((int *) PyArray_GETPTR2(a, i, j))
  50. #define IIND3(a, i, j, k) *((int *) Py_Array_GETPTR3(a, i, j, k))
  51.  
  52.  
  /*
 * DEF_PYARRAY_GETTER(funcname, selftype, valname)
 *   Defines a getset getter returning self->valname.  PyArray_Return steals
 *   a reference (and turns a 0-d array into a scalar), hence the INCREF.
 *
 * DEF_PYARRAY_SETTER(funcname, selftype, valname, arraydim)
 *   Defines a getset setter that accepts only a NumPy array of rank
 *   `arraydim` with dtype NPY_DOUBLE; deletion is refused.  The old array is
 *   released only after the new one is installed, so no code triggered by
 *   the release can ever observe a dangling self->valname.
 *   Returns 0 on success, -1 with an exception set otherwise.
 */
#define DEF_PYARRAY_GETTER(funcname, selftype, valname) \
  static PyObject * \
  funcname(selftype *self, void *closure) \
  { \
    (void)closure; \
    Py_INCREF(self->valname); \
    return PyArray_Return(self->valname); \
  }

#define DEF_PYARRAY_SETTER(funcname, selftype, valname, arraydim) \
  static int \
  funcname(selftype *self, PyObject *value, void *closure) \
  { \
    PyArrayObject *old_value_; \
    (void)closure; \
    if (value == NULL) { \
      PyErr_SetString(PyExc_TypeError, \
                      "Cannot delete the " #valname " attribute"); \
      return -1; \
    } \
    if (!PyArray_Check(value)) { \
      PyErr_SetString(PyExc_ValueError, \
                      "value is not of type numpy array"); \
      return -1; \
    } \
    if (PyArray_NDIM((PyArrayObject *)value) != (arraydim)) { \
      PyErr_Format(PyExc_ValueError, \
                   "value array's dimension %d != %d", \
                   PyArray_NDIM((PyArrayObject *)value), (int)(arraydim)); \
      return -1; \
    } \
    if (PyArray_TYPE((PyArrayObject *)value) != NPY_DOUBLE) { \
      PyErr_SetString(PyExc_ValueError, \
                      "value array is not of type 'Python float'"); \
      return -1; \
    } \
    old_value_ = self->valname; \
    Py_INCREF(value); \
    self->valname = (PyArrayObject *)value; \
    Py_XDECREF(old_value_); \
    return 0; \
  }
  90.  
  91. /* ========================================================== DTLSys struct */
  92. typedef struct {
  93. PyObject_HEAD
  94. PyArrayObject *wt;
  95. PyArrayObject *bs;
  96. PyArrayObject *xt;
  97. } DTLSys;
  98.  
  99. /* ============================================================ Declaration */
  100. static void DTLSys_dealloc(DTLSys* self);
  101. static PyObject * DTLSys_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
  102. static int DTLSys_init(DTLSys *self, PyObject *args, PyObject *kwds);
  103.  
  104. static PyMemberDef DTLSys_members[] = {
  105. {NULL} /* Sentinel */
  106. };
  107.  
  108. static PyObject * DTLSys_get_wt(DTLSys *self, void *closure);
  109. static int DTLSys_set_wt(DTLSys *self, PyObject *value, void *closure);
  110. static PyObject * DTLSys_get_bs(DTLSys *self, void *closure);
  111. static int DTLSys_set_bs(DTLSys *self, PyObject *value, void *closure);
  112. static PyObject * DTLSys_get_xt(DTLSys *self, void *closure);
  113. static int DTLSys_set_xt(DTLSys *self, PyObject *value, void *closure);
  114.  
  115. static PyGetSetDef DTLSys_getseters[] = {
  116. {"wt", (getter)DTLSys_get_wt, (setter)DTLSys_set_wt, "Matrix", NULL},
  117. {"bs", (getter)DTLSys_get_bs, (setter)DTLSys_set_bs, "Vector", NULL},
  118. {"xt", (getter)DTLSys_get_xt, (setter)DTLSys_set_xt, "Vector time sequence", NULL},
  119. {NULL} /* Sentinel */
  120. };
  121.  
  122. static int _DTLSys_check_sys_conf(DTLSys *self);
  123. static PyObject * DTLSys_check_sys_conf(DTLSys *self);
  124. static PyObject * DTLSys_make_tms(DTLSys *self, PyObject *args);
  125.  
  126. static PyMethodDef DTLSys_methods[] = {
  127. {"check_sys_conf", (PyCFunction)DTLSys_check_sys_conf, METH_NOARGS, "Check if system config is correct"},
  128. {"make_tms", (PyCFunction)DTLSys_make_tms, METH_VARARGS, "Make TiMe Series"},
  129. {NULL} /* Sentinel */
  130. };
  131.  
  132. static PyTypeObject DTLSysType = {
  133. PyObject_HEAD_INIT(NULL)
  134. 0, /*ob_size*/
  135. "dtls.DTLSys", /*tp_name*/
  136. sizeof(DTLSys), /*tp_basicsize*/
  137. 0, /*tp_itemsize*/
  138. (destructor)DTLSys_dealloc, /*tp_dealloc*/
  139. 0, /*tp_print*/
  140. 0, /*tp_getattr*/
  141. 0, /*tp_setattr*/
  142. 0, /*tp_compare*/
  143. 0, /*tp_repr*/
  144. 0, /*tp_as_number*/
  145. 0, /*tp_as_sequence*/
  146. 0, /*tp_as_mapping*/
  147. 0, /*tp_hash */
  148. 0, /*tp_call*/
  149. 0, /*tp_str*/
  150. 0, /*tp_getattro*/
  151. 0, /*tp_setattro*/
  152. 0, /*tp_as_buffer*/
  153. Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
  154. "DTLSys objects", /* tp_doc */
  155. 0, /* tp_traverse */
  156. 0, /* tp_clear */
  157. 0, /* tp_richcompare */
  158. 0, /* tp_weaklistoffset */
  159. 0, /* tp_iter */
  160. 0, /* tp_iternext */
  161. DTLSys_methods, /* tp_methods */
  162. DTLSys_members, /* tp_members */
  163. DTLSys_getseters, /* tp_getset */
  164. 0, /* tp_base */
  165. 0, /* tp_dict */
  166. 0, /* tp_descr_get */
  167. 0, /* tp_descr_set */
  168. 0, /* tp_dictoffset */
  169. (initproc)DTLSys_init, /* tp_init */
  170. 0, /* tp_alloc */
  171. DTLSys_new, /* tp_new */
  172. };
  173.  
  174. static PyMethodDef module_methods[] = {
  175. {NULL} /* Sentinel */
  176. };
  177.  
  178. #ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
  179. #define PyMODINIT_FUNC void
  180. #endif
  181. PyMODINIT_FUNC
  182. initdtls(void)
  183. {
  184. PyObject* m;
  185. if (PyType_Ready(&DTLSysType) < 0){ return; }
  186.  
  187. m = Py_InitModule3("dtls", module_methods,
  188. "Example module that creates an extension type.");
  189. if (m == NULL){ return; }
  190.  
  191. Py_INCREF(&DTLSysType);
  192. PyModule_AddObject(m, "DTLSys", (PyObject *)&DTLSysType);
  193. import_array(); /* required NumPy initialization */
  194. }
  195.  
  196. /* ================================================================= Define */
  197. static void
  198. DTLSys_dealloc(DTLSys* self)
  199. {
  200. Py_XDECREF(self->wt);
  201. Py_XDECREF(self->bs);
  202. Py_XDECREF(self->xt);
  203. self->ob_type->tp_free((PyObject*)self);
  204. }
  205.  
  206. static PyObject *
  207. DTLSys_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
  208. {
  209. DTLSys *self;
  210. npy_intp wt_dims[2] = {0,0};
  211. npy_intp bs_dims[1] = {0};
  212. npy_intp xt_dims[2] = {0,0};
  213.  
  214. self = (DTLSys *)type->tp_alloc(type, 0);
  215. if (self != NULL) {
  216. self->wt = (PyArrayObject *) PyArray_SimpleNew(2, wt_dims, NPY_DOUBLE);
  217. if (self->wt == NULL){ Py_DECREF(self); return NULL; }
  218. self->bs = (PyArrayObject *) PyArray_SimpleNew(1, bs_dims, NPY_DOUBLE);
  219. if (self->bs == NULL){ Py_DECREF(self); return NULL; }
  220. self->xt = (PyArrayObject *) PyArray_SimpleNew(2, xt_dims, NPY_DOUBLE);
  221. if (self->xt == NULL){ Py_DECREF(self); return NULL; }
  222. }
  223.  
  224. return (PyObject *)self;
  225. }
  226.  
  227. static int
  228. DTLSys_init(DTLSys *self, PyObject *args, PyObject *kwds)
  229. {
  230. PyArrayObject *wt, *bs, *tmp;
  231. int t_max;
  232. npy_intp xt_dims[2] = {0,0};
  233.  
  234. if ( !PyArg_ParseTuple( args, "O!O!i:DTLSys.init",
  235. &PyArray_Type, &wt,
  236. &PyArray_Type, &bs,
  237. &t_max)
  238. ) {
  239. return -1; /* PyArg_ParseTuple has raised an exception */
  240. }
  241. if ( wt==NULL || bs==NULL ) {
  242. printf("getting args failed\n"); return -1;
  243. }
  244. if ( t_max < 0 ) {
  245. printf("t_max (3rd arg) must be positive int\n"); return -1;
  246. }
  247.  
  248. xt_dims[0] = PyArray_DIM(wt,0);
  249. xt_dims[1] = t_max;
  250. self->xt = (PyArrayObject *) PyArray_SimpleNew(2, xt_dims, NPY_DOUBLE);
  251. if (self->xt == NULL){
  252. printf("creating %dx%d array failed\n", (int)xt_dims[0], (int)xt_dims[1]);
  253. return -1;
  254. }
  255.  
  256. tmp = self->wt; Py_INCREF(wt); self->wt = wt; Py_DECREF(tmp);
  257. tmp = self->bs; Py_INCREF(bs); self->bs = bs; Py_DECREF(tmp);
  258.  
  259. if( _DTLSys_check_sys_conf(self) != 0 ){
  260. PyErr_Clear();
  261. printf("DTLSys config is not correct!\n");
  262. }
  263. return 0;
  264. }
  265.  
  266. DEF_PYARRAY_GETTER( DTLSys_get_wt, DTLSys, wt )
  267. DEF_PYARRAY_SETTER( DTLSys_set_wt, DTLSys, wt, 2 )
  268. DEF_PYARRAY_GETTER( DTLSys_get_bs, DTLSys, bs )
  269. DEF_PYARRAY_SETTER( DTLSys_set_bs, DTLSys, bs, 1 )
  270. DEF_PYARRAY_GETTER( DTLSys_get_xt, DTLSys, xt )
  271. DEF_PYARRAY_SETTER( DTLSys_set_xt, DTLSys, xt, 2 )
  272.  
  273. static int
  274. _DTLSys_check_sys_conf(DTLSys *self)
  275. {
  276. int vecsize;
  277.  
  278. NDIM_CHECK(self->wt, 2, -1); TYPE_CHECK(self->wt, NPY_DOUBLE, -1);
  279. NDIM_CHECK(self->bs, 1, -1); TYPE_CHECK(self->bs, NPY_DOUBLE, -1);
  280.  
  281. vecsize = PyArray_DIM(self->wt,0);
  282. if (vecsize != PyArray_DIM(self->wt,1) ) {
  283. PyErr_Format( PyExc_ValueError, "self.wt must be square");
  284. return -1;
  285. }
  286. if (vecsize != PyArray_DIM(self->bs,0) ) {
  287. PyErr_Format( PyExc_ValueError, "self.bs and self.wt[0] must be same shape");
  288. return -1;
  289. }
  290. if (vecsize != PyArray_DIM(self->xt,0) ) {
  291. PyErr_Format( PyExc_ValueError, "self.xt[,0] and self.wt[0] must be same shape");
  292. return -1;
  293. }
  294. return 0;
  295. }
  296.  
  297. static PyObject *
  298. DTLSys_check_sys_conf(DTLSys *self)
  299. {
  300. if( _DTLSys_check_sys_conf(self) != 0 ){
  301. PyErr_Clear();
  302. Py_RETURN_FALSE;
  303. }
  304. Py_RETURN_TRUE;
  305. }
  306.  
  307. static PyObject *
  308. DTLSys_make_tms(DTLSys *self, PyObject *args)
  309. {
  310. int vecsize, t_max, i, j, t;
  311. if( _DTLSys_check_sys_conf(self) != 0 ){
  312. return NULL;
  313. }
  314. vecsize = PyArray_DIM(self->wt,0);
  315. t_max = PyArray_DIM(self->xt,1);
  316.  
  317. for (t = 1; t < t_max; t++) {
  318. for (i = 0; i < vecsize; i++) {
  319. DIND2(self->xt,i,t) = DIND1(self->bs,i);
  320. for (j = 0; j < vecsize; j++) {
  321. DIND2(self->xt,i,t) += DIND2(self->wt,i,j) * DIND2(self->xt,j,t-1);
  322. }
  323. }
  324. }
  325. return Py_BuildValue(""); /* return None */
  326. }
  327.  
  328. /*
  329. # setup.py
  330. # build command : python setup.py build build_ext --inplace
  331. from numpy.distutils.core import setup, Extension
  332. import os, numpy
  333.  
  334. name = 'dtls'
  335. sources = ['dtlsmodule.c']
  336.  
  337. include_dirs = [
  338.   numpy.get_include()
  339.   ]
  340.  
  341. setup( name = name,
  342.   include_dirs = include_dirs,
  343.   ext_modules = [Extension(name, sources)]
  344.   )
  345. */
  346.  
  347. /*
  348. # test code
  349. import scipy, pylab
  350. import dtls
  351.  
  352. t_max = 200
  353. rot = 5.0 * 2.0 * scipy.pi / t_max
  354. wt = scipy.array([
  355.   [ scipy.cos(rot), scipy.sin(rot) ],
  356.   [ -scipy.sin(rot), scipy.cos(rot) ]
  357.   ])
  358. wt *= 0.99
  359. bs = scipy.array([0.,0.])
  360.  
  361. a=dtls.DTLSys( wt, bs, t_max )
  362. a.xt[0,0] = 0
  363. a.xt[1,0] = 1
  364. a.make_tms()
  365.  
  366. pylab.clf()
  367. pylab.plot(a.xt[0],a.xt[1], 'o-')
  368.  
  369. # Check calculation
  370. print (scipy.dot( a.wt, a.xt[:,0] ) + a.bs) - a.xt[:,1]
  371. */

Report this snippet


Comments

RSS Icon Subscribe to comments

You need to login to post a comment.