# -*- coding: utf-8 -*-
"""
此文件由 LuWu 自动生成
"""
import tensorflow as tf

{{data_preprocess_template}}

def build_model():
    pre_trained_net: tf.keras.Model = tf.keras.applications.{{net_name}}()
    pre_trained_net.trainable = False
    # 记录densenet
    model = tf.keras.Sequential(
        [
            pre_trained_net,
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1000, activation="relu"),
            tf.keras.layers.Dropout(0.3),
            tf.keras.layers.Dense({{num_classes}}, activation="softmax"),
        ]
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.categorical_crossentropy,
        metrics=["accuracy"],
    )
    return pre_trained_net, model

# Mapping from predicted class index to class name.
num_to_classes_map = {{num_classes_map}}
# Load the trained model from disk.
model = tf.keras.models.load_model("{{model_path}}")
# Load and preprocess the image (read_image is expected to come from the
# data-preprocessing section rendered earlier in this file).
image = read_image("/path/to/image")
# Reshape the image into a batch of 224x224 three-channel inputs and predict.
batch = tf.reshape(image, (-1, 224, 224, 3))
outputs = model.predict(batch)
# Take the highest-scoring class of the first prediction and report its name.
index = int(tf.argmax(outputs[0]))
label = num_to_classes_map[index]
print("当前图片类别为：{}".format(label))