Python 示例:多阶段计算

以下示例显示多阶段转换函数,该函数用于计算输入表中数字列的平均值。它会先定义两个转换函数,然后定义使用这些函数创建阶段的工厂。

有关完整的代码,请参阅示例分发中的 AvgMultiPhaseUDT.py

加载和使用示例

创建库和函数:

=> CREATE LIBRARY pylib_avg AS '/home/dbadmin/udx/AvgMultiPhaseUDT.py' LANGUAGE 'Python';
CREATE LIBRARY
=> CREATE TRANSFORM FUNCTION myAvg AS NAME 'MyAvgFactory' LIBRARY pylib_avg;
CREATE TRANSFORM FUNCTION

然后,您可以在 SELECT 语句中使用该函数:

=> CREATE TABLE IF NOT EXISTS numbers(num FLOAT);
CREATE TABLE

=> COPY numbers FROM STDIN delimiter ',';
1
2
3
4
\.

=> SELECT myAvg(num) OVER() FROM numbers;
 average | ignored_rows | total_rows
---------+--------------+------------
     2.5 |            0 |          4
(1 row)

设置

所有 Python UDx 都必须导入 Vertica SDK。此示例还会导入另一个库。

import vertica_sdk
import math

分量转换函数

多阶段转换函数必须定义两个或更多要在阶段中使用的 TransformFunction 子类。此示例会使用两个类:LocalCalculation(用于对本地分区进行计算)以及 GlobalCalculation(用于聚合所有 LocalCalculation 实例的结果以计算最终结果)。

在这两个函数中,计算都是在 processPartition() 函数中完成:

class LocalCalculation(vertica_sdk.TransformFunction):
    """
    This class is the first phase and calculates the local values for sum, ignored_rows and total_rows.
    """

    def setup(self, server_interface, col_types):
        server_interface.log("Setup: Phase0")
        self.local_sum = 0.0
        self.ignored_rows = 0
        self.total_rows = 0

    def processPartition(self, server_interface, input, output):
        server_interface.log("Process Partition: Phase0")

        while True:
            self.total_rows += 1

            if input.isNull(0) or math.isinf(input.getFloat(0)) or math.isnan(input.getFloat(0)):
                # Null, Inf, or Nan is ignored
                self.ignored_rows += 1
            else:
                self.local_sum += input.getFloat(0)

            if not input.next():
                break

        output.setFloat(0, self.local_sum)
        output.setInt(1, self.ignored_rows)
        output.setInt(2, self.total_rows)
        output.next()

class GlobalCalculation(vertica_sdk.TransformFunction):
    """
    This class is the second phase and aggregates the values for sum, ignored_rows and total_rows.
    """

    def setup(self, server_interface, col_types):
        server_interface.log("Setup: Phase1")
        self.global_sum = 0.0
        self.ignored_rows = 0
        self.total_rows = 0

    def processPartition(self, server_interface, input, output):
        server_interface.log("Process Partition: Phase1")

        while True:
            self.global_sum += input.getFloat(0)
            self.ignored_rows += input.getInt(1)
            self.total_rows += input.getInt(2)

            if not input.next():
                break

        average = self.global_sum / (self.total_rows - self.ignored_rows)

        output.setFloat(0, average)
        output.setInt(1, self.ignored_rows)
        output.setInt(2, self.total_rows)
        output.next()

多阶段工厂

MultiPhaseTransformFunctionFactory 会将各个函数作为阶段绑定在一起。工厂会为每个函数定义 TransformFunctionPhase。每个阶段都会定义 createTransformFunction(),用于调用对应的 TransformFunctiongetReturnType() 的构造函数。

下面是第一阶段 LocalPhase


class MyAvgFactory(vertica_sdk.MultiPhaseTransformFunctionFactory):
    """ Factory class """

    class LocalPhase(vertica_sdk.TransformFunctionPhase):
        """ Phase 1 """
        def getReturnType(self, server_interface, input_types, output_types):
            # sanity check
            number_of_cols = input_types.getColumnCount()
            if (number_of_cols != 1 or not input_types.getColumnType(0).isFloat()):
                raise ValueError("Function only accepts one argument (FLOAT))")

            output_types.addFloat("local_sum");
            output_types.addInt("ignored_rows");
            output_types.addInt("total_rows");

        def createTransformFunction(cls, server_interface):
            return LocalCalculation()

第二阶段 GlobalPhase 不会检查其输入,因为第一阶段已经进行了检查。与第一阶段一样,createTransformFunction 只是构造并返回对应的 TransformFunction


    class GlobalPhase(vertica_sdk.TransformFunctionPhase):
        """ Phase 2 """
        def getReturnType(self, server_interface, input_types, output_types):
            output_types.addFloat("average");
            output_types.addInt("ignored_rows");
            output_types.addInt("total_rows");

        def createTransformFunction(cls, server_interface):
            return GlobalCalculation()

在定义 TransformFunctionPhase 子类之后,工厂会将其实例化并在 getPhases() 中将其链接在一起。

    ph0Instance = LocalPhase()
    ph1Instance = GlobalPhase()

    def getPhases(cls, server_interface):
        cls.ph0Instance.setPrepass()
        phases = [cls.ph0Instance, cls.ph1Instance]
        return phases