Skip to content

Commit 74695a1

Browse files
authored
DIVFM (#42)
1 parent b6ec63c commit 74695a1

25 files changed

Lines changed: 1779 additions & 303 deletions

.vscode/launch.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"args": [
1515
"-x",
1616
"-vvv",
17-
"quantflow_tests/test_options_pricer.py",
17+
"quantflow_tests/test_divfm.py",
1818
]
1919
},
2020
]

app/gaussian_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import marimo
1+
import marimo
22

33
__generated_with = "0.19.7"
44
app = marimo.App(width="medium")

app/heston_divfm_fit.py

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
import marimo
2+
3+
__generated_with = "0.22.0"
4+
app = marimo.App(width="medium")
5+
6+
7+
@app.cell
8+
def _():
9+
import marimo as mo
10+
from app.utils import nav_menu
11+
nav_menu()
12+
return (mo,)
13+
14+
15+
@app.cell(hide_code=True)
16+
def _(mo):
17+
mo.md(r"""
18+
# Deep Implied Volatility Factor Model
19+
""")
20+
return
21+
22+
23+
@app.cell
24+
def _():
25+
import numpy as np
26+
import torch
27+
28+
from quantflow.options.divfm.network import DIVFMNetwork
29+
from quantflow.options.divfm.trainer import DayData, DIVFMTrainer
30+
from quantflow.options.pricer import OptionPricer
31+
from quantflow.sp.heston import HestonJ
32+
from quantflow.utils.distributions import DoubleExponential
33+
34+
# ---------------------------------------------------------------------------
35+
# Grid settings
36+
# ---------------------------------------------------------------------------
37+
38+
TTM_GRID = [0.1, 0.25, 0.5, 1.0, 2.0]
39+
MAX_MONEYNESS_TTM = 1.5 # moneyness_ttm range for sampling and pricing
40+
N_PER_TTM = 20 # random options sampled per TTM per day
41+
42+
# ---------------------------------------------------------------------------
43+
# HestonJ parameter ranges (uniform sampling)
44+
# ---------------------------------------------------------------------------
45+
46+
PARAM_RANGES = {
47+
"vol": (0.10, 0.70),
48+
"rho": (-0.80, 0.10),
49+
"kappa": (0.50, 5.00),
50+
"sigma": (0.20, 1.50),
51+
"jump_fraction": (0.1, 0.50),
52+
"jump_asymmetry": (-0.50, 0.50),
53+
}
54+
55+
56+
# ---------------------------------------------------------------------------
57+
# Fixture generation
58+
# ---------------------------------------------------------------------------
59+
60+
61+
def _make_pricer(rng: np.random.Generator) -> OptionPricer:
62+
"""Sample a random HestonJ parameter set and return a ready pricer."""
63+
vol = float(rng.uniform(*PARAM_RANGES["vol"]))
64+
rho = float(rng.uniform(*PARAM_RANGES["rho"]))
65+
kappa = float(rng.uniform(*PARAM_RANGES["kappa"]))
66+
sigma = float(rng.uniform(*PARAM_RANGES["sigma"]))
67+
jump_fraction = float(rng.uniform(*PARAM_RANGES["jump_fraction"]))
68+
jump_asymmetry = float(rng.uniform(*PARAM_RANGES["jump_asymmetry"]))
69+
sv = sigma/vol
70+
kappa = max(kappa, 0.6*sv*sv)
71+
72+
model = HestonJ.create(
73+
DoubleExponential,
74+
vol=vol,
75+
kappa=kappa,
76+
rho=rho,
77+
sigma=sigma,
78+
jump_fraction=jump_fraction,
79+
jump_asymmetry=jump_asymmetry,
80+
)
81+
return OptionPricer(model=model, max_moneyness_ttm=MAX_MONEYNESS_TTM)
82+
83+
84+
def _sample_day(rng: np.random.Generator, pricer: OptionPricer) -> DayData | None:
85+
"""Price options at random (moneyness_ttm, ttm) points and return DayData.
86+
87+
Returns None if all points are invalid (e.g. numerical pricing failure).
88+
"""
89+
m_list: list[np.ndarray] = []
90+
t_list: list[np.ndarray] = []
91+
iv_list: list[np.ndarray] = []
92+
93+
for ttm in TTM_GRID:
94+
mat = pricer.maturity(ttm)
95+
m_ttm = rng.uniform(-MAX_MONEYNESS_TTM, MAX_MONEYNESS_TTM, N_PER_TTM).astype(
96+
np.float32
97+
)
98+
moneyness = m_ttm * np.sqrt(ttm)
99+
ivs = np.interp(moneyness, mat.moneyness, mat.implied_vols)
100+
101+
# drop any degenerate points (NaN / non-positive IV)
102+
valid = np.isfinite(ivs) & (ivs > 0)
103+
if not valid.any():
104+
continue
105+
106+
m_list.append(m_ttm[valid])
107+
t_list.append(np.full(valid.sum(), ttm, dtype=np.float32))
108+
iv_list.append(ivs[valid].astype(np.float64))
109+
110+
if not m_list:
111+
return None
112+
113+
return DayData(
114+
moneyness_ttm=np.concatenate(m_list),
115+
ttm=np.concatenate(t_list),
116+
implied_vols=np.concatenate(iv_list),
117+
)
118+
119+
120+
def generate_fixtures(
121+
num_days: int = 300,
122+
seed: int = 42,
123+
verbose: bool = True,
124+
) -> list[DayData]:
125+
"""Generate *num_days* synthetic IV days from random HestonJ parameters.
126+
127+
Each day is a different random parameter set, giving the DIVFM model a
128+
diverse training distribution that covers varying vol levels, skews, and
129+
term structures.
130+
"""
131+
rng = np.random.default_rng(seed)
132+
days: list[DayData] = []
133+
skipped = 0
134+
135+
for i in range(num_days):
136+
pricer = _make_pricer(rng)
137+
day = _sample_day(rng, pricer)
138+
if day is None:
139+
skipped += 1
140+
else:
141+
days.append(day)
142+
143+
if verbose and (i + 1) % 50 == 0:
144+
print(f" generated {i + 1}/{num_days} parameter sets ({len(days)} valid)")
145+
146+
if verbose:
147+
print(f"Fixture generation done: {len(days)} valid days, {skipped} skipped")
148+
149+
return days
150+
151+
152+
def fit_divfm(
153+
days: list[DayData],
154+
num_factors: int = 5,
155+
hidden_size: int = 32,
156+
num_hidden_layers: int = 3,
157+
lr: float = 1e-3,
158+
batch_days: int = 32,
159+
num_steps: int = 500,
160+
val_fraction: float = 0.1,
161+
seed: int = 0,
162+
log_every: int = 50,
163+
) -> tuple[DIVFMNetwork, list[float]]:
164+
"""Train a DIVFMNetwork on the given days.
165+
166+
Splits days into train/val, trains the network, and returns the trained
167+
network together with the per-step training losses.
168+
"""
169+
torch.manual_seed(seed)
170+
171+
n_val = max(1, int(len(days) * val_fraction))
172+
train_days = days[n_val:]
173+
val_days = days[:n_val]
174+
175+
net = DIVFMNetwork(
176+
num_factors=num_factors,
177+
hidden_size=hidden_size,
178+
num_hidden_layers=num_hidden_layers,
179+
)
180+
trainer = DIVFMTrainer(net, lr=lr, batch_days=batch_days)
181+
182+
print(
183+
f"Training DIVFM factors={num_factors} hidden={hidden_size}"
184+
f" layers={num_hidden_layers} lr={lr}"
185+
f" batch_days={batch_days} steps={num_steps}"
186+
)
187+
print(f" train days: {len(train_days)} val days: {len(val_days)}")
188+
189+
losses = trainer.fit(
190+
train_days,
191+
num_steps=num_steps,
192+
val_days=val_days,
193+
log_every=log_every,
194+
)
195+
196+
val_loss = trainer.evaluate(val_days)
197+
print(f"Final val loss: {val_loss:.6f}")
198+
199+
return net, losses
200+
201+
202+
return fit_divfm, generate_fixtures, np, torch
203+
204+
205+
@app.cell
206+
def _(generate_fixtures):
207+
days = generate_fixtures(num_days=300, seed=42)
208+
return (days,)
209+
210+
211+
@app.cell
212+
def _(days, fit_divfm):
213+
net, losses = fit_divfm(days, num_steps=500, log_every=50)
214+
return (net,)
215+
216+
217+
@app.cell
218+
def _():
219+
return
220+
221+
222+
@app.cell
223+
def _(mo, net, np, torch):
224+
import plotly.graph_objects as go
225+
226+
# 1. Create the coordinate grid
227+
m_range = np.linspace(-1.5, 1.5, 40) # moneyness_ttm
228+
t_range = np.linspace(0.1, 2.0, 40) # ttm
229+
M, T = np.meshgrid(m_range, t_range)
230+
231+
# Flatten the grid to feed into the neural network
232+
M_flat = M.flatten()
233+
T_flat = T.flatten()
234+
235+
# Prepare inputs for the network
236+
M_tensor = torch.tensor(M_flat, dtype=torch.float32)
237+
T_tensor = torch.tensor(T_flat, dtype=torch.float32)
238+
239+
# 2. Evaluate the network to get the factors
240+
with torch.no_grad():
241+
factors_pred = net(M_tensor, T_tensor).numpy()
242+
243+
# 3. Create a Plotly figure for factors 1, 2, 3, and 4
244+
tabs_dict = {}
245+
for i in range(1, 5):
246+
# Reshape the 1D factor output back into the 2D grid shape
247+
Z = factors_pred[:, i].reshape(M.shape)
248+
249+
fig = go.Figure(data=[go.Surface(x=M, y=T, z=Z, colorscale='Viridis')])
250+
251+
fig.update_layout(
252+
title=f"DIVFM Learned Factor {i}",
253+
scene=dict(
254+
xaxis_title='Moneyness / √TTM',
255+
yaxis_title='Time to Maturity',
256+
zaxis_title=f'Factor {i} Value',
257+
camera=dict(eye=dict(x=1.8, y=1.8, z=0.8)),
258+
dragmode="turntable"
259+
),
260+
margin=dict(l=0, r=0, b=0, t=40)
261+
)
262+
263+
tabs_dict[f"Factor {i}"] = fig
264+
265+
# 4. Display them in an interactive tabbed interface
266+
mo.ui.tabs(tabs_dict)
267+
return
268+
269+
270+
@app.cell
271+
def _():
272+
return
273+
274+
275+
if __name__ == "__main__":
276+
app.run()

0 commit comments

Comments
 (0)