https://github.com/JuliaDiffEq/DiffEqFlux.jl
Revision 05a3545d642163ae5d88423cc2f725723b402acd authored by Christopher Rackauckas on 17 July 2019, 02:54:28 UTC, committed by GitHub on 17 July 2019, 02:54:28 UTC
1 parent e522ff0
Raw File
Tip revision: 05a3545d642163ae5d88423cc2f725723b402acd authored by Christopher Rackauckas on 17 July 2019, 02:54:28 UTC
Update README.md
Tip revision: 05a3545
odenet.jl
using DiffEqFlux, Flux, Test, OrdinaryDiffEq 
using Statistics
#= using Plots =#

## True Solution
# Initial state, number of saved time points, and integration window.
u0 = Float64[2, 0]
datasize = 30
tspan = (0.0, 25.0)

# Coefficient matrix of the ground-truth dynamics. It is hoisted out of the RHS so it
# is built once instead of on every solver call; `const` keeps global reads type-stable.
const TRUE_A = [-0.1 2.0; -2.0 -0.1]

"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the ground-truth system: writes `TRUE_A' * u.^3` into
`du` and returns `du`. The arguments `p` and `t` are unused.
"""
function trueODEfunc(du, u, p, t)
    # Same value as the row-vector form ((u.^3)' * TRUE_A)', because (v' A)' == A' v.
    du .= TRUE_A' * (u .^ 3)
    return du
end

# Ground-truth trajectory, saved at `datasize` evenly spaced times across `tspan`.
true_prob = ODEProblem(trueODEfunc, u0, tspan)
true_sol = solve(true_prob, Tsit5();
                 saveat = range(first(tspan), last(tspan); length = datasize))

#= true_sol_plot = solve(true_prob,Tsit5()) =#
#= plot(true_sol_plot) =#

## Neural ODE
# Learned vector field: a 2 -> 50 (tanh) -> 2 multilayer perceptron.
dudt = let nhidden = 50
    Chain(Dense(2, nhidden, tanh), Dense(nhidden, 2))
end

"""
    ODEfunc(du, u, p, t)

In-place right-hand side for the neural ODE. It evaluates the global network `dudt`
at state `u`, unwraps the output with `Flux.data`, writes it into `du`, and returns
`du`. The arguments `p` and `t` are unused.
"""
function ODEfunc(du, u, p, t)
    nn_out = Flux.data(dudt(u))
    du .= nn_out
    return du
end

# Trajectory of the (untrained) neural ODE, saved at the same times as `true_sol`.
pred_prob = ODEProblem(ODEfunc, u0, tspan)
pred_sol = solve(pred_prob, Tsit5();
                 saveat = range(first(tspan), last(tspan); length = datasize))

## Loss
"""
    l1_loss(pred, target)

Return the mean absolute error between `pred` and `target`.

Plain `pred - target` (not `.-`) is used so that arrays with incompatible shapes
raise a `DimensionMismatch` instead of being silently broadcast. `mean(abs, ...)`
then applies `abs` during the reduction. This avoids the second temporary array
that `abs.(...)` would allocate.
"""
l1_loss(pred, target) = mean(abs, pred - target)
# Smoke check: the untrained network's trajectory must give a finite L1 loss against
# the ground truth. The value is kept in `initial_loss` for inspection.
initial_loss = l1_loss(pred_sol, true_sol)
@test isfinite(initial_loss)


back to top