TensorFlow Serving | SavedModel Warmup 模型预热篇

作者: Anoyi

打开App

概述

TensorFlow 运行时具有延迟初始化的组件,这可能导致加载后发送给模型的第一个请求的等待时间较长,此延迟可能比单个推理请求的延迟高几个数量级。为了减少初始化的延迟对请求的影响,可以在模型加载时通过提供一组推理请求样本和 SavedModel 来触发子系统和组件的初始化,此过程称为 “预热” 模型。

使用方式

Regress、Classify、MultiInference、Predict 支持模型预热,要在加载时触发模型预热,需要在 SavedModel 目录的 assets.extra 子文件夹下附加一个预热数据文件。

模型预热正常工作的要求:

  • 预热文件名称:tf_serving_warmup_requests
  • 文件路径:assets.extra/
  • 文件格式:PredictionLog
  • 预热记录数 <= 1000
  • 必须使用相对真实的请求生成预热数据

示例预热文件存放路径: models/saved_model_half_plus_two_cpu/1/assets.extra/tf_serving_warmup_requests

预热文件生成

import tensorflow as tf
from tensorflow_serving.apis import classification_pb2
from tensorflow_serving.apis import inference_pb2
from tensorflow_serving.apis import model_pb2
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_log_pb2
from tensorflow_serving.apis import regression_pb2

def main():
    with tf.compat.v1.python_io.TFRecordWriter("tf_serving_warmup_requests") as writer:
        # 将 <request> 替换为下面的其中一个:
        # predict_pb2.PredictRequest(..)
        # classification_pb2.ClassificationRequest(..)
        # regression_pb2.RegressionRequest(..)
        # inference_pb2.MultiInferenceRequest(..)
        log = prediction_log_pb2.PredictionLog(predict_log=prediction_log_pb2.PredictLog(request=<request>))
        writer.write(log.SerializeToString())

if __name__ == "__main__":
    main()
看法

看法

昵称
邮箱