TensorFlow Serving | SavedModel Warmup 模型预热篇
作者: Anoyi
打开App

概述
TensorFlow 运行时具有延迟初始化的组件,这可能导致加载后发送给模型的第一个请求的等待时间较长,此延迟可能比单个推理请求的延迟高几个数量级。为了减少初始化的延迟对请求的影响,可以在模型加载时通过提供一组推理请求样本和 SavedModel 来触发子系统和组件的初始化,此过程称为 “预热” 模型。
使用方式
Regress、Classify、MultiInference、Predict 支持模型预热,要在加载时触发模型预热,需要在 SavedModel 目录的 assets.extra
子文件夹下附加一个预热数据文件。
模型预热正常工作的要求:
- 预热文件名称:
tf_serving_warmup_requests
- 文件路径:
assets.extra/
- 文件格式:PredictionLog
- 预热记录数 <= 1000
- 必须使用相对真实的请求生成预热数据
示例预热文件存放路径: models/saved_model_half_plus_two_cpu/1/assets.extra/tf_serving_warmup_requests
预热文件生成
import tensorflow as tf
from tensorflow_serving.apis import classification_pb2
from tensorflow_serving.apis import inference_pb2
from tensorflow_serving.apis import model_pb2
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_log_pb2
from tensorflow_serving.apis import regression_pb2
def main():
with tf.compat.v1.python_io.TFRecordWriter("tf_serving_warmup_requests") as writer:
# 将 <request> 替换为下面的其中一个:
# predict_pb2.PredictRequest(..)
# classification_pb2.ClassificationRequest(..)
# regression_pb2.RegressionRequest(..)
# inference_pb2.MultiInferenceRequest(..)
log = prediction_log_pb2.PredictionLog(predict_log=prediction_log_pb2.PredictLog(request=<request>))
writer.write(log.SerializeToString())
if __name__ == "__main__":
main()
©
著作权归作者所有,转载或内容合作请联系作者
昵称
邮箱
看法