Skip to content

Commit c752be3

Browse files
committed
Fix dimension mismatch in numerical comparison
- Create common time grid (union of direct and indirect grids) - Interpolate both solutions on common grid before L2 comparison - Use t_common for all L2 norm calculations
1 parent 8793e52 commit c752be3

1 file changed

Lines changed: 25 additions & 15 deletions

File tree

docs/src/tutorial-goddard.md

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -430,18 +430,28 @@ plot!(plt, indirect_sol; label="indirect", color=2)
430430

431431
function print_numerical_comparisons(direct_sol, indirect_sol)
432432

433-
# get relevant data from direct solution
433+
# get time grids
434434
t_dir = time_grid(direct_sol)
435-
x_dir = state(direct_sol).(t_dir)
436-
u_dir = control(direct_sol).(t_dir)
435+
t_ind = time_grid(indirect_sol)
436+
437+
# create common time grid (union of both grids)
438+
t_common = unique(sort([t_dir..., t_ind...]))
439+
440+
# interpolate both solutions on common grid
441+
x_dir_func = state(direct_sol)
442+
u_dir_func = control(direct_sol)
443+
x_ind_func = state(indirect_sol)
444+
u_ind_func = control(indirect_sol)
445+
446+
x_dir = x_dir_func.(t_common)
447+
u_dir = u_dir_func.(t_common)
448+
x_ind = x_ind_func.(t_common)
449+
u_ind = u_ind_func.(t_common)
450+
437451
v_dir = variable(direct_sol)
438452
o_dir = objective(direct_sol)
439453
i_dir = iterations(direct_sol)
440454

441-
# get relevant data from indirect solution
442-
t_ind = time_grid(indirect_sol)
443-
x_ind = state(indirect_sol).(t_ind)
444-
u_ind = control(indirect_sol).(t_ind)
445455
v_ind = variable(indirect_sol)
446456
o_ind = objective(indirect_sol)
447457

@@ -457,20 +467,20 @@ plot!(plt, indirect_sol; label="indirect", color=2)
457467
# States
458468
println("├─ States (L2 Norms)")
459469
for i in eachindex(x_vars)
460-
xi_dir = [x_dir[k][i] for k in eachindex(t_dir)]
461-
xi_ind = [x_ind[k][i] for k in eachindex(t_ind)]
462-
L2_ae = L2_norm(t_dir, xi_dir - xi_ind)
463-
L2_re = L2_ae / (0.5 * (L2_norm(t_dir, xi_dir) + L2_norm(t_dir, xi_ind)))
470+
xi_dir = [x_dir[k][i] for k in eachindex(t_common)]
471+
xi_ind = [x_ind[k][i] for k in eachindex(t_common)]
472+
L2_ae = L2_norm(t_common, xi_dir - xi_ind)
473+
L2_re = L2_ae / (0.5 * (L2_norm(t_common, xi_dir) + L2_norm(t_common, xi_ind)))
464474
@printf("│ %-6s Abs: %.3e Rel: %.3e\\n", x_vars[i], L2_ae, L2_re)
465475
end
466476

467477
# Controls
468478
println("├─ Controls (L2 Norms)")
469479
for i in eachindex(u_vars)
470-
ui_dir = [u_dir[k][i] for k in eachindex(t_dir)]
471-
ui_ind = [u_ind[k][i] for k in eachindex(t_ind)]
472-
L2_ae = L2_norm(t_dir, ui_dir - ui_ind)
473-
L2_re = L2_ae / (0.5 * (L2_norm(t_dir, ui_dir) + L2_norm(t_dir, ui_ind)))
480+
ui_dir = [u_dir[k][i] for k in eachindex(t_common)]
481+
ui_ind = [u_ind[k][i] for k in eachindex(t_common)]
482+
L2_ae = L2_norm(t_common, ui_dir - ui_ind)
483+
L2_re = L2_ae / (0.5 * (L2_norm(t_common, ui_dir) + L2_norm(t_common, ui_ind)))
474484
@printf("│ %-6s Abs: %.3e Rel: %.3e\\n", u_vars[i], L2_ae, L2_re)
475485
end
476486

0 commit comments

Comments
 (0)