NeuralForecast 推理 - 最简单的推理方式

发布于:2024-06-04 ⋅ 阅读:(82) ⋅ 点赞:(0)

NeuralForecast 推理 - 最简单的推理方式

flyfish

最简单的保存和加载模型代码

import pandas as pd
import numpy as np

AirPassengers = np.array(
    [112.0, 118.0, 132.0, 129.0, 121.0, 135.0, 148.0, 148.0, 136.0, 119.0],
    dtype=np.float32,
)

AirPassengersDF = pd.DataFrame(
    {
        "unique_id": np.ones(len(AirPassengers)),
        "ds": pd.date_range(
            start="1949-01-01", periods=len(AirPassengers), freq=pd.offsets.MonthEnd()
        ),
        "y": AirPassengers,
    }
)

Y_df = AirPassengersDF
Y_df = Y_df.reset_index(drop=True)
Y_df.head()
#Model Training

from neuralforecast.core import NeuralForecast
from neuralforecast.models import NBEATS

horizon = 2
models = [NBEATS(input_size=2 * horizon, h=horizon, max_steps=50)]

nf = NeuralForecast(models=models, freq='M')
nf.fit(df=Y_df)


#Save models
nf.save(path='./checkpoints/test_run/',
        model_index=None, 
        overwrite=True,
        save_dataset=True)

#Load models
nf2 = NeuralForecast.load(path='./checkpoints/test_run/')
Y_hat_df = nf2.predict().reset_index()
Y_hat_df.head()

简单的预测

import numpy as np
from neuralforecast.core import NeuralForecast
from neuralforecast.models import NBEATS

# 新的输入数据
new_data = pd.DataFrame(
    {
        "unique_id": [1.0, 1.0],
        "ds": pd.to_datetime(["1949-01-31", "1949-02-28"]),
        "y": [112.0, 118.0],
    }
)

# 确保数据的顺序和索引是正确的
new_data = new_data.reset_index(drop=True)
print("New input data:")
print(new_data)

# 加载已保存的模型
nf2 = NeuralForecast.load(path='./checkpoints/test_run/')

# 使用已加载的模型进行预测
Y_hat_df = nf2.predict(df=new_data).reset_index()
print("Prediction results:")
print(Y_hat_df)

.reset_index() 的作用如下:

重置索引:将 DataFrame 的索引重置为默认的整数索引。默认情况下,DataFrame 的索引可以是行标签,但有时候需要将其重置为默认的整数索引。
转换索引为列:如果索引是有意义的数据,可以选择将索引转换为 DataFrame 的一列数据。

.reset_index() 方法有几个常用参数:
drop:布尔值。如果为 True,则会删除索引列而不是将其转换为数据列。
inplace:布尔值。如果为 True,则会在原地修改 DataFrame 而不是返回一个新的 DataFrame。

日期索引被重置为默认的整数索引,并且原来的索引变成了 DataFrame 的一列

示例代码

import pandas as pd

data = {
    'value': [10, 20, 30, 40]
}
index = pd.date_range(start='2022-01-01', periods=4, freq='D')
df = pd.DataFrame(data, index=index)
print("Original DataFrame:")
print(df)


df_reset = df.reset_index()
print("\nDataFrame after reset_index:")
print(df_reset)

结果

Original DataFrame:
            value
2022-01-01     10
2022-01-02     20
2022-01-03     30
2022-01-04     40

DataFrame after reset_index:
       index  value
0 2022-01-01     10
1 2022-01-02     20
2 2022-01-03     30
3 2022-01-04     40