|
import streamlit as st |
|
import pandas as pd |
|
import torch |
|
from chronos import ChronosPipeline |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
|
|
@st.cache_resource
def load_pipeline():
    """Return the Chronos forecasting pipeline used by this app.

    Loads the ``amazon/chronos-t5-small`` checkpoint onto the CPU with
    float32 weights. Because the function is wrapped in
    ``st.cache_resource``, the model is built once and the same object is
    handed back on every later script rerun instead of being reloaded.
    """
    model_id = "amazon/chronos-t5-small"
    return ChronosPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map="cpu",
    )
|
|
|
# Fetch the (cached) forecasting model before any widgets are drawn.
pipeline = load_pipeline()

# Page header.
st.title("Time Series Forecasting Demo with Deep Learning models")
st.write("This demo uses the ChronosPipeline model for time series forecasting.")
|
|
|
|
|
# Example series pre-filled in the input box: 144 comma-separated values.
# NOTE(review): these look like the classic monthly airline-passenger
# (AirPassengers) figures — confirm before documenting it as such.
default_data = """
112, 118, 132, 129, 121, 135, 148, 148, 136, 119, 104, 118, 115, 126, 141, 135, 125, 149, 170, 170, 158,
133, 114, 140, 145, 150, 178, 163, 172, 178, 199, 199, 184, 162, 146, 166, 171, 180, 193, 181, 183, 218,
230, 242, 209, 191, 172, 194, 196, 196, 236, 235, 229, 243, 264, 272, 237, 211, 180, 201, 204, 188, 235,
227, 234, 264, 302, 293, 259, 229, 203, 229, 242, 233, 267, 269, 270, 315, 364, 347, 312, 274, 237, 278,
284, 277, 317, 313, 318, 374, 413, 405, 355, 306, 271, 306, 315, 301, 356, 348, 355, 422, 465, 467, 404,
347, 305, 336, 340, 318, 362, 348, 363, 435, 491, 505, 404, 359, 310, 337, 360, 342, 406, 396, 420, 472,
548, 559, 463, 407, 362, 405, 417, 391, 419, 461, 472, 535, 622, 606, 508, 461, 390, 432
"""

# Free-text box holding the series; starts out with the example above
# (surrounding whitespace removed).
user_input = st.text_area(
    label="Enter time series data (comma-separated values):",
    value=default_data.strip(),
)
|
|
|
|
|
def process_input(input_str):
    """Parse a comma-separated string of numbers into a list of floats.

    Whitespace (including newlines) around each entry is ignored. Blank
    entries, such as those produced by a trailing comma or by ",," are
    skipped, so pasting "1, 2, 3," works.

    Args:
        input_str: Raw text typed by the user.

    Returns:
        The parsed values, in input order.

    Raises:
        ValueError: if an entry is not a number, if an entry is infinite,
            or if the text contains no numbers at all.
    """
    import math

    values = []
    for token in input_str.split(","):
        token = token.strip()
        if not token:
            # Tolerate stray/trailing commas rather than rejecting the input.
            continue
        value = float(token)  # raises ValueError for non-numeric text
        if math.isinf(value):
            # An infinite observation would break the forecast downstream.
            raise ValueError(f"infinite value not allowed: {token!r}")
        values.append(value)
    if not values:
        # Keep empty input an error so the caller shows its message.
        raise ValueError("no numeric values found")
    return values
|
|
|
# Parse the text box. On malformed input, report the problem and keep an
# empty series so the forecasting section below is skipped.
time_series_data = []
try:
    time_series_data = process_input(user_input)
except ValueError:
    st.error("Please make sure all values are numbers, separated by commas.")

# How many future steps to forecast (1 to 64, default 12).
prediction_length = st.slider(
    "Select Forecast Horizon (Months)",
    min_value=1,
    max_value=64,
    value=12,
)
|
|
|
|
|
if time_series_data:
    # Past observations form the model's context as a 1-D float tensor.
    context = torch.tensor(time_series_data, dtype=torch.float32)

    # Inference can take a while on CPU; give the user visible feedback.
    with st.spinner("Generating forecast..."):
        forecast = pipeline.predict(
            context=context,
            prediction_length=prediction_length,
            num_samples=20,
        )

    # Forecast steps continue the integer x-axis used for the history.
    n_obs = len(time_series_data)
    forecast_index = range(n_obs, n_obs + prediction_length)
    # Assumes predict() returns (num_series, num_samples, prediction_length),
    # per the Chronos docs: forecast[0] is our single series, and quantiles
    # are taken across the sample axis to get one value per forecast step.
    low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    # Use an explicit Figure rather than pyplot's implicit global figure.
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(time_series_data, color="royalblue", label="Historical data")
    ax.plot(forecast_index, median, color="tomato", label="Median forecast")
    # The 10th-90th percentile band is the 80% interval.
    ax.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
    ax.legend()
    ax.grid()

    st.pyplot(fig)
    # Streamlit reruns this script on every interaction. Close the figure so
    # pyplot's global registry does not gain one open figure per rerun.
    plt.close(fig)
|
|
|
|
|
# Footer: section heading followed by a contact line (both rendered as Markdown).
for _footer_text in (
    "### Notes",
    "For comments, feedback, or any questions, please reach out to me on [LinkedIn](https://www.linkedin.com/in/mjdarvishi/).",
):
    st.write(_footer_text)