LCOV - code coverage report
Current view: top level - sql/backends/monet5/UDF/pyapi3 - pyloader3.c (source / functions) Hit Total Coverage
Test: coverage.info Lines: 151 199 75.9 %
Date: 2020-06-29 20:00:14 Functions: 2 2 100.0 %

          Line data    Source code
       1             : /*
       2             :  * This Source Code Form is subject to the terms of the Mozilla Public
       3             :  * License, v. 2.0.  If a copy of the MPL was not distributed with this
       4             :  * file, You can obtain one at http://mozilla.org/MPL/2.0/.
       5             :  *
       6             :  * Copyright 1997 - July 2008 CWI, August 2008 - 2020 MonetDB B.V.
       7             :  */
       8             : 
       9             : #include "monetdb_config.h"
      10             : #include "pyapi.h"
      11             : #include "conversion.h"
      12             : #include "connection.h"
      13             : #include "emit.h"
      14             : 
      15             : #include "unicode.h"
      16             : #include "pytypes.h"
      17             : #include "gdk_interprocess.h"
      18             : #include "type_conversion.h"
      19             : #include "formatinput.h"
      20             : 
      21           7 : static void _loader_import_array(void) { _import_array(); }
      22             : 
      23           7 : str _loader_init(void)
      24             : {
      25           7 :         str msg = MAL_SUCCEED;
      26           7 :         _loader_import_array();
      27           7 :         msg = _emit_init();
      28           7 :         if (msg != MAL_SUCCEED) {
      29             :                 return msg;
      30             :         }
      31             : 
      32           7 :         if (PyType_Ready(&Py_ConnectionType) < 0)
      33           0 :                 return createException(MAL, "pyapi3.eval",
      34             :                                                            SQLSTATE(PY000) "Failed to initialize loader functions.");
      35             :         return msg;
      36             : }
      37             : 
      38             : static int
      39          72 : pyapi_list_length(list *l)
      40             : {
      41          72 :         if (l)
      42          48 :                 return l->cnt;
      43             :         return 0;
      44             : }
      45             : 
      46             : str
      47          25 : PYFUNCNAME(PyAPIevalLoader)(Client cntxt, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci) {
      48          25 :     sql_func * sqlfun;
      49          25 :     sql_subfunc * sqlmorefun;
      50          25 :     str exprStr;
      51             : 
      52          25 :         const int additional_columns = 2;
      53          25 :         int i = 1, ai = 0;
      54          25 :         char *pycall = NULL;
      55          25 :         str *args = NULL;
      56          25 :         char *msg = MAL_SUCCEED;
      57          25 :         node *argnode, *n, *n2;
      58          25 :         PyObject *pArgs = NULL, *pEmit = NULL,
      59             :                          *pConnection; // this is going to be the parameter tuple
      60          25 :         PyObject *code_object = NULL;
      61          25 :         sql_emit_col *cols = NULL;
      62          25 :         bool gstate = 0;
      63          25 :         int unnamedArgs = 0;
      64          25 :         int argcount = pci->argc;
      65          25 :         bool create_table = false;
      66          25 :         BUN nval = 0;
      67          25 :         int ncols = 0;
      68             : 
      69          25 :         char *loader_additional_args[] = {"_emit", "_conn"};
      70             : 
      71          25 :     if (!PYFUNCNAME(PyAPIInitialized())) {
      72           0 :         throw(MAL, "pyapi3.eval",
      73             :               SQLSTATE(PY000) "Embedded Python is enabled but an error was thrown during initialization.");
      74             :     }
      75          25 :     sqlmorefun = *(sql_subfunc**) getArgReference(stk, pci, pci->retc);
      76          25 :     sqlfun = sqlmorefun->func;
      77          25 :     exprStr = *getArgReference_str(stk, pci, pci->retc + 1);
      78             : 
      79          25 :         args = (str *)GDKzalloc(pci->argc * sizeof(str));
      80          25 :         if (!args) {
      81           0 :                 throw(MAL, "pyapi3.eval", SQLSTATE(HY013) MAL_MALLOC_FAIL " arguments.");
      82             :         }
      83             : 
      84             :         // Analyse the SQL_Func structure to get the parameter names
      85          25 :         if (sqlfun != NULL && sqlfun->ops->cnt > 0) {
      86          22 :                 unnamedArgs = pci->retc + 2;
      87          22 :                 argnode = sqlfun->ops->h;
      88          66 :                 while (argnode) {
      89          44 :                         char *argname = ((sql_arg *)argnode->data)->name;
      90          44 :                         args[unnamedArgs++] = GDKstrdup(argname);
      91          44 :                         argnode = argnode->next;
      92             :                 }
      93             :         }
      94             : 
      95             :         // We name all the unknown arguments
      96          69 :         for (i = pci->retc + 2; i < argcount; i++) {
      97          44 :                 if (!args[i]) {
      98           0 :                         char argbuf[64];
      99           0 :                         snprintf(argbuf, sizeof(argbuf), "arg%i", i - pci->retc - 1);
     100           0 :                         args[i] = GDKstrdup(argbuf);
     101             :                 }
     102             :         }
     103          25 :         gstate = Python_ObtainGIL();
     104             : 
     105          25 :         pArgs = PyTuple_New(argcount - pci->retc - 2 + additional_columns);
     106          25 :         if (!pArgs) {
     107           0 :                 msg = createException(MAL, "pyapi3.eval_loader",
     108             :                                                           SQLSTATE(HY013) MAL_MALLOC_FAIL "python object");
     109           0 :                 goto wrapup;
     110             :         }
     111             : 
     112          25 :         ai = 0;
     113          25 :         argnode = sqlfun && sqlfun->ops->cnt > 0 ? sqlfun->ops->h : NULL;
     114          69 :         for (i = pci->retc + 2; i < argcount; i++) {
     115          44 :                 PyInput inp;
     116          44 :                 PyObject *val = NULL;
     117          44 :                 inp.bat = NULL;
     118          44 :                 inp.sql_subtype = NULL;
     119             : 
     120          44 :                 if (!isaBatType(getArgType(mb, pci, i))) {
     121          44 :                         inp.scalar = true;
     122          44 :                         inp.bat_type = getArgType(mb, pci, i);
     123          44 :                         inp.count = 1;
     124          44 :                         if (inp.bat_type == TYPE_str) {
     125          18 :                                 inp.dataptr = getArgReference_str(stk, pci, i);
     126             :                         } else {
     127          26 :                                 inp.dataptr = getArgReference(stk, pci, i);
     128             :                         }
     129          44 :                         val = PyArrayObject_FromScalar(&inp, &msg);
     130             :                 } else {
     131           0 :                         BAT* b = BATdescriptor(*getArgReference_bat(stk, pci, i));
     132           0 :                         if (b == NULL) {
     133           0 :                                 msg = createException(
     134             :                                         MAL, "pyapi3.eval_loader",
     135             :                                         SQLSTATE(PY000) "The BAT passed to the function (argument #%d) is NULL.\n",
     136           0 :                                         i - (pci->retc + 2) + 1);
     137           0 :                                 goto wrapup;
     138             :                         }
     139           0 :                         inp.scalar = false;
     140           0 :                         inp.count = BATcount(b);
     141           0 :                         inp.bat_type = getBatType(getArgType(mb, pci, i));
     142           0 :                         inp.bat = b;
     143             : 
     144           0 :                         val = PyMaskedArray_FromBAT(
     145             :                                 &inp, 0, inp.count, &msg,
     146             :                                 false);
     147           0 :                         BBPunfix(inp.bat->batCacheid);
     148             :                 }
     149          44 :                 if (msg != MAL_SUCCEED) {
     150           0 :                         goto wrapup;
     151             :                 }
     152          44 :                 if (PyTuple_SetItem(pArgs, ai++, val) != 0) {
     153           0 :                         msg =
     154           0 :                                 createException(MAL, "pyapi3.eval_loader",
     155             :                                                                 SQLSTATE(PY000) "Failed to set tuple (this shouldn't happen).");
     156           0 :                         goto wrapup;
     157             :                 }
     158             :                 // TODO deal with sql types
     159             :         }
     160             : 
     161          25 :         getArg(pci, 0) = TYPE_void;
     162          25 :         if (sqlmorefun->colnames) {
     163          24 :                 n = sqlmorefun->colnames->h;
     164          24 :                 n2 = sqlmorefun->coltypes->h;
     165          24 :                 ncols = pyapi_list_length(sqlmorefun->coltypes);
     166          24 :                 if (ncols == 0) {
     167           0 :                         msg = createException(MAL, "pyapi3.eval_loader",
     168             :                                                                   "No columns supplied.");
     169           0 :                         goto wrapup;
     170             :                 }
     171          24 :                 cols = GDKzalloc(sizeof(sql_emit_col) * ncols);
     172          24 :                 if (!cols) {
     173           0 :                         msg = createException(MAL, "pyapi3.eval_loader",
     174             :                                                                   SQLSTATE(HY013) MAL_MALLOC_FAIL "column list");
     175           0 :                         goto wrapup;
     176             :                 }
     177          72 :                 assert(pyapi_list_length(sqlmorefun->colnames) == pyapi_list_length(sqlmorefun->coltypes) * 2);
     178             :                 i = 0;
     179         158 :                 while (n) {
     180         134 :                         sql_subtype* tpe = (sql_subtype*) n2->data;
     181         134 :                         cols[i].name = GDKstrdup(*((char **)n->data));
     182         134 :                         n = n->next;
     183         134 :                         assert(n);
     184         134 :                         cols[i].def = n->data;
     185         134 :                         n = n->next;
     186         268 :                         cols[i].b =
     187         134 :                                 COLnew(0, tpe->type->localtype, 0, TRANSIENT);
     188         134 :                         n2 = n2->next;
     189         134 :                         cols[i].b->tnil = false;
     190         134 :                         cols[i].b->tnonil = false;
     191         134 :                         i++;
     192             :                 }
     193             :         } else {
     194             :                 // set the return value to the correct type to prevent MAL layers from
     195             :                 // complaining
     196             :                 cols = NULL;
     197             :                 ncols = 0;
     198             :                 create_table = true;
     199             :         }
     200             : 
     201          25 :         pConnection = Py_Connection_Create(cntxt, 0, 0, 0);
     202          25 :         pEmit = PyEmit_Create(cols, ncols);
     203          25 :         if (!pConnection || !pEmit) {
     204           0 :                 msg = createException(MAL, "pyapi3.eval_loader",
     205             :                                                           SQLSTATE(HY013) MAL_MALLOC_FAIL "python object");
     206           0 :                 goto wrapup;
     207             :         }
     208             : 
     209          25 :         PyTuple_SetItem(pArgs, ai++, pEmit);
     210          25 :         PyTuple_SetItem(pArgs, ai++, pConnection);
     211             : 
     212          25 :         pycall = FormatCode(exprStr, args, argcount, 4, &code_object, &msg,
     213             :                                                 loader_additional_args, additional_columns);
     214          25 :         if (!pycall && !code_object) {
     215           0 :                 if (msg == MAL_SUCCEED) {
     216           0 :                         msg = createException(MAL, "pyapi3.eval_loader",
     217             :                                                                   SQLSTATE(PY000) "Error while parsing Python code.");
     218             :                 }
     219           0 :                 goto wrapup;
     220             :         }
     221             : 
     222             :         {
     223          25 :                 PyObject *pFunc, *pModule, *v, *d, *ret;
     224             : 
     225             :                 // First we will load the main module, this is required
     226          25 :                 pModule = PyImport_AddModule("__main__");
     227          25 :                 if (!pModule) {
     228           0 :                         msg = PyError_CreateException("Failed to load module", NULL);
     229           0 :                         goto wrapup;
     230             :                 }
     231             : 
     232             :                 // Now we will add the UDF to the main module
     233          25 :                 d = PyModule_GetDict(pModule);
     234          25 :                 if (code_object == NULL) {
     235          25 :                         v = PyRun_StringFlags(pycall, Py_file_input, d, d, NULL);
     236          25 :                         if (v == NULL) {
     237           0 :                                 msg = PyError_CreateException("Could not parse Python code",
     238             :                                                                                           pycall);
     239           0 :                                 goto wrapup;
     240             :                         }
     241          25 :                         Py_DECREF(v);
     242             : 
     243             :                         // Now we need to obtain a pointer to the function, the function is
     244             :                         // called "pyfun"
     245          25 :                         pFunc = PyObject_GetAttrString(pModule, "pyfun");
     246          25 :                         if (!pFunc || !PyCallable_Check(pFunc)) {
     247           0 :                                 msg = PyError_CreateException("Failed to load function", NULL);
     248           0 :                                 goto wrapup;
     249             :                         }
     250             :                 } else {
     251           0 :                         pFunc = PyFunction_New(code_object, d);
     252           0 :                         if (!pFunc || !PyCallable_Check(pFunc)) {
     253           0 :                                 msg = PyError_CreateException("Failed to load function", NULL);
     254           0 :                                 goto wrapup;
     255             :                         }
     256             :                 }
     257          25 :                 ret = PyObject_CallObject(pFunc, pArgs);
     258             : 
     259          25 :                 if (PyErr_Occurred()) {
     260          12 :                         Py_DECREF(pFunc);
     261          12 :                         msg = PyError_CreateException("Python exception", pycall);
     262          12 :                         if (code_object == NULL) {
     263          12 :                                 PyRun_SimpleString("del pyfun");
     264             :                         }
     265          12 :                         goto wrapup;
     266             :                 }
     267             : 
     268          13 :                 if (ret != Py_None) {
     269           0 :                         if (PyEmit_Emit((PyEmitObject *)pEmit, ret) == NULL) {
     270           0 :                                 Py_DECREF(pFunc);
     271           0 :                                 msg = PyError_CreateException("Python exception", pycall);
     272           0 :                                 goto wrapup;
     273             :                         }
     274             :                 }
     275             : 
     276          13 :                 cols = ((PyEmitObject *)pEmit)->cols;
     277          13 :                 nval = ((PyEmitObject *)pEmit)->nvals;
     278          13 :                 ncols = (int)((PyEmitObject *)pEmit)->ncols;
     279          13 :                 Py_DECREF(pFunc);
     280          13 :                 Py_DECREF(pArgs);
     281          13 :                 pArgs = NULL;
     282             : 
     283          13 :                 if (ncols == 0) {
     284           0 :                         msg = createException(MAL, "pyapi3.eval_loader",
     285             :                                                                   SQLSTATE(PY000) "No elements emitted by the loader.");
     286           0 :                         goto wrapup;
     287             :                 }
     288             :         }
     289             : 
     290          13 :         gstate = Python_ReleaseGIL(gstate);
     291             : 
     292          78 :         for (i = 0; i < ncols; i++) {
     293          52 :                 BAT *b = cols[i].b;
     294          52 :                 BATsetcount(b, nval);
     295          52 :                 b->tkey = false;
     296          52 :                 b->tsorted = false;
     297          52 :                 b->trevsorted = false;
     298             :         }
     299          13 :         if (!create_table) {
     300          12 :                 msg = _connection_append_to_table(cntxt, sqlmorefun->sname,
     301             :                                                                            sqlmorefun->tname, cols, ncols);
     302          12 :                 goto wrapup;
     303             :         } else {
     304           1 :                 msg = _connection_create_table(cntxt, sqlmorefun->sname,
     305             :                                                                            sqlmorefun->tname, cols, ncols);
     306           1 :                 goto wrapup;
     307             :         }
     308             : 
     309          25 : wrapup:
     310          25 :         if (cols) {
     311         161 :                 for (i = 0; i < ncols; i++) {
     312         136 :                         if (cols[i].b) {
     313         136 :                                 BBPunfix(cols[i].b->batCacheid);
     314             :                         }
     315         136 :                         if (cols[i].name) {
     316         136 :                                 GDKfree(cols[i].name);
     317             :                         }
     318             :                 }
     319          25 :                 GDKfree(cols);
     320             :         }
     321          25 :         if (gstate) {
     322          12 :                 if (pArgs) {
     323          12 :                         Py_DECREF(pArgs);
     324             :                 }
     325          12 :                 gstate = Python_ReleaseGIL(gstate);
     326             :         }
     327          25 :         if (pycall)
     328          25 :                 GDKfree(pycall);
     329          25 :         if (args) {
     330          69 :                 for (i = pci->retc + 2; i < argcount; i++) {
     331          44 :                         if (args[i]) {
     332          44 :                                 GDKfree(args[i]);
     333             :                         }
     334             :                 }
     335          25 :                 GDKfree(args);
     336             :         }
     337          25 :         return (msg);
     338             : }

Generated by: LCOV version 1.14