Python 示例:矩阵乘法

Python UDx 可以接受并返回复杂类型。MatrixMultiply 类会乘以输入矩阵,并返回生成的矩阵乘积。这些矩阵将以二维数组来表示。为了执行矩阵乘法运算,第一个输入矩阵中的列数必须等于第二个输入矩阵中的行数。

完整的源代码位于 /opt/vertica/sdk/examples/python/ScalarFunctions.py 中。

加载和使用示例

加载库并创建函数,如下所示:

=> CREATE OR REPLACE LIBRARY ScalarFunctions AS '/home/dbadmin/examples/python/ScalarFunctions.py' LANGUAGE 'Python';

=> CREATE FUNCTION MatrixMultiply AS LANGUAGE 'Python' NAME 'matrix_multiply_factory' LIBRARY ScalarFunctions;

您可以创建输入矩阵,然后调用诸如以下函数:


=> CREATE TABLE mn (id INTEGER, data ARRAY[ARRAY[INTEGER, 3], 2]);
CREATE TABLE

=> CREATE TABLE np (id INTEGER, data ARRAY[ARRAY[INTEGER, 2], 3]);
CREATE TABLE

=> COPY mn FROM STDIN PARSER fjsonparser();
{"id": 1, "data": [[1, 2, 3], [4, 5, 6]] }
{"id": 2, "data": [[7, 8, 9], [10, 11, 12]] }
\.

=> COPY np FROM STDIN PARSER fjsonparser();
{"id": 1, "data": [[0, 0], [0, 0], [0, 0]] }
{"id": 2, "data": [[1, 1], [1, 1], [1, 1]] }
{"id": 3, "data": [[2, 0], [0, 2], [2, 0]] }
\.

=> SELECT mn.id, np.id, MatrixMultiply(mn.data, np.data) FROM mn CROSS JOIN np ORDER BY 1, 2;
id | id |   MatrixMultiply
---+----+-------------------
1  |  1 | [[0,0],[0,0]]
1  |  2 | [[6,6],[15,15]]
1  |  3 | [[8,4],[20,10]]
2  |  1 | [[0,0],[0,0]]
2  |  2 | [[24,24],[33,33]]
2  |  3 | [[32,16],[44,22]]
(6 rows)

设置

所有 Python UDx 都必须导入 Vertica SDK 库:

import vertica_sdk

工厂实施

getPrototype() 方法会声明函数实参和返回类型都必须为二维数组,以整数数组的数组来表示:


def getPrototype(self, srv_interface, arg_types, return_type):
    array1dtype = vertica_sdk.ColumnTypes.makeArrayType(vertica_sdk.ColumnTypes.makeInt())
    arg_types.addArrayType(array1dtype)
    arg_types.addArrayType(array1dtype)
    return_type.addArrayType(array1dtype)

getReturnType() 验证乘积矩阵的行数是否与第一个输入矩阵相同,以及列数是否与第二个输入矩阵相同:


def getReturnType(self, srv_interface, arg_types, return_type):
    (_, a1type) = arg_types[0]
    (_, a2type) = arg_types[1]
    m = a1type.getArrayBound()
    p = a2type.getElementType().getArrayBound()
    return_type.addArrayType(vertica_sdk.SizedColumnTypes.makeArrayType(vertica_sdk.SizedColumnTypes.makeInt(), p), m)

函数实施

使用名称分别为 arg_readerres_writerBlockReaderBlockWriter 来调用 processBlock() 方法。为了访问输入数组的元素,该方法会使用 ArrayReader 实例。数组是嵌套的,因此必须为外部和内部数组实例化 ArrayReader。列表推导式简化了将输入数组读取到列表中的过程。该方法会执行计算,然后使用 ArrayWriter 实例来构造乘积矩阵。


def processBlock(self, server_interface, arg_reader, res_writer):
    while True:
        lmat = [[cell.getInt(0) for cell in row.getArrayReader(0)] for row in arg_reader.getArrayReader(0)]
        rmat = [[cell.getInt(0) for cell in row.getArrayReader(0)] for row in arg_reader.getArrayReader(1)]
        omat = [[0 for c in range(len(rmat[0]))] for r in range(len(lmat))]

        for i in range(len(lmat)):
            for j in range(len(rmat[0])):
                for k in range(len(rmat)):
                    omat[i][j] += lmat[i][k] * rmat[k][j]

        res_writer.setArray(omat)
        res_writer.next()

        if not arg_reader.next():
            break