Python 示例:矩阵乘法
Python UDx 可以接受并返回复杂类型。MatrixMultiply
类会乘以输入矩阵,并返回生成的矩阵乘积。这些矩阵将以二维数组来表示。为了执行矩阵乘法运算,第一个输入矩阵中的列数必须等于第二个输入矩阵中的行数。
完整的源代码位于 /opt/vertica/sdk/examples/python/ScalarFunctions.py
中。
加载和使用示例
加载库并创建函数,如下所示:
=> CREATE OR REPLACE LIBRARY ScalarFunctions AS '/home/dbadmin/examples/python/ScalarFunctions.py' LANGUAGE 'Python';
=> CREATE FUNCTION MatrixMultiply AS LANGUAGE 'Python' NAME 'matrix_multiply_factory' LIBRARY ScalarFunctions;
您可以创建输入矩阵,然后调用诸如以下函数:
=> CREATE TABLE mn (id INTEGER, data ARRAY[ARRAY[INTEGER, 3], 2]);
CREATE TABLE
=> CREATE TABLE np (id INTEGER, data ARRAY[ARRAY[INTEGER, 2], 3]);
CREATE TABLE
=> COPY mn FROM STDIN PARSER fjsonparser();
{"id": 1, "data": [[1, 2, 3], [4, 5, 6]] }
{"id": 2, "data": [[7, 8, 9], [10, 11, 12]] }
\.
=> COPY np FROM STDIN PARSER fjsonparser();
{"id": 1, "data": [[0, 0], [0, 0], [0, 0]] }
{"id": 2, "data": [[1, 1], [1, 1], [1, 1]] }
{"id": 3, "data": [[2, 0], [0, 2], [2, 0]] }
\.
=> SELECT mn.id, np.id, MatrixMultiply(mn.data, np.data) FROM mn CROSS JOIN np ORDER BY 1, 2;
id | id | MatrixMultiply
---+----+-------------------
1 | 1 | [[0,0],[0,0]]
1 | 2 | [[6,6],[15,15]]
1 | 3 | [[8,4],[20,10]]
2 | 1 | [[0,0],[0,0]]
2 | 2 | [[24,24],[33,33]]
2 | 3 | [[32,16],[44,22]]
(6 rows)
设置
所有 Python UDx 都必须导入 Vertica SDK 库:
import vertica_sdk
工厂实施
getPrototype()
方法会声明函数实参和返回类型都必须为二维数组,以整数数组的数组来表示:
def getPrototype(self, srv_interface, arg_types, return_type):
array1dtype = vertica_sdk.ColumnTypes.makeArrayType(vertica_sdk.ColumnTypes.makeInt())
arg_types.addArrayType(array1dtype)
arg_types.addArrayType(array1dtype)
return_type.addArrayType(array1dtype)
getReturnType()
验证乘积矩阵的行数是否与第一个输入矩阵相同,以及列数是否与第二个输入矩阵相同:
def getReturnType(self, srv_interface, arg_types, return_type):
(_, a1type) = arg_types[0]
(_, a2type) = arg_types[1]
m = a1type.getArrayBound()
p = a2type.getElementType().getArrayBound()
return_type.addArrayType(vertica_sdk.SizedColumnTypes.makeArrayType(vertica_sdk.SizedColumnTypes.makeInt(), p), m)
函数实施
使用名称分别为 arg_reader
和 res_writer
的 BlockReader
和 BlockWriter
来调用 processBlock()
方法。为了访问输入数组的元素,该方法会使用 ArrayReader
实例。数组是嵌套的,因此必须为外部和内部数组实例化 ArrayReader
。列表推导式简化了将输入数组读取到列表中的过程。该方法会执行计算,然后使用 ArrayWriter
实例来构造乘积矩阵。
def processBlock(self, server_interface, arg_reader, res_writer):
while True:
lmat = [[cell.getInt(0) for cell in row.getArrayReader(0)] for row in arg_reader.getArrayReader(0)]
rmat = [[cell.getInt(0) for cell in row.getArrayReader(0)] for row in arg_reader.getArrayReader(1)]
omat = [[0 for c in range(len(rmat[0]))] for r in range(len(lmat))]
for i in range(len(lmat)):
for j in range(len(rmat[0])):
for k in range(len(rmat)):
omat[i][j] += lmat[i][k] * rmat[k][j]
res_writer.setArray(omat)
res_writer.next()
if not arg_reader.next():
break