https://github.com/JuliaDiffEq/DiffEqFlux.jl
Revision 9af1cd7313fe1b3e7d235bc9f42212961ae93a1e authored by Christopher Rackauckas on 08 April 2020, 18:36:50 UTC, committed by GitHub on 08 April 2020, 18:36:50 UTC
2 parents 35cfab2 + 424703e
Raw File
Tip revision: 9af1cd7313fe1b3e7d235bc9f42212961ae93a1e authored by Christopher Rackauckas on 08 April 2020, 18:36:50 UTC
Merge pull request #220 from ranjanan/RA/quaddirect
Tip revision: 9af1cd7
neural_dae.jl
using Flux, DiffEqFlux, OrdinaryDiffEq, Optim
using Random

#A desired MWE for now, not a test yet.

# Robertson stiff reaction system written in mass-matrix form:
# two differential equations plus one conservation residual, which the
# zero third row of the mass matrix `M` turns into an algebraic constraint.
function f(du, u, p, t)
    k₁, k₂, k₃ = p
    y₁, y₂, y₃ = u
    r₁ = k₁ * y₁        # decay of y₁
    r₂ = k₂ * y₂^2      # quadratic self-reaction of y₂
    r₃ = k₃ * y₂ * y₃   # recombination of y₂ and y₃
    du[1] = r₃ - r₁
    du[2] = r₁ - r₃ - r₂
    du[3] = (y₁ + y₂ + y₃) - 1   # conservation residual: concentrations sum to 1
    return nothing
end

# Robertson problem setup: initial condition, singular mass matrix
# (zero third row ⇒ algebraic constraint), timespan, and rate constants.
u₀ = [1.0, 0, 0]
M = [1. 0  0
    0  1. 0
    0  0  0]
tspan = (0.0,10.0)
p = [0.04,3e7,1e4]
# The mass matrix must be attached via ODEFunction. The original code
# built `func` but then constructed the problem from the bare `f`, which
# silently dropped M and integrated the constraint row as an ODE.
func = ODEFunction(f,mass_matrix=M)
prob = ODEProblem(func,u₀,tspan,p)  # pass `p`, not a duplicated literal tuple
# Rodas5 supports mass-matrix DAEs; `sol` is the reference trajectory
# that the neural DAE is trained against below.
sol = solve(prob,Rodas5())


# Neural network core for the NeuralDAE. Output width is 2 — presumably
# one derivative per differential state, with the third state governed by
# the explicit constraint below (TODO confirm against NeuralDAE docs).
dudt2 = Chain(x -> x.^3,Dense(3,50,tanh),Dense(50,2))


# Neural DAE: learned dynamics plus the algebraic constraint
# u₁ + u₂ + u₃ - 1 = 0, sharing the mass matrix M and solver Rodas5
# with the reference problem above.
ndae = NeuralDAE(dudt2, (u,p,t) -> [u[1] + u[2] + u[3] - 1], tspan, M, Rodas5())

# Forward evaluation at the initial condition (smoke check).
ndae(u₀)

# Predict the trajectory of the neural DAE for parameter vector `p`,
# starting from the fixed initial condition `u₀`.
predict_n_dae(p) = ndae(u₀, p)

# Squared-error loss of the neural-DAE prediction against the reference
# solution `sol`. Returns (loss, prediction) — the pair sciml_train expects.
function loss(p)
    pred = predict_n_dae(p)
    return sum(abs2, sol .- pred), pred
end

# Perturb the true rate constants to create an initial guess away from the
# optimum. Seed the RNG first so this MWE (a test-to-be) is reproducible;
# the original unseeded rand gave a different starting point every run.
Random.seed!(1234)
p = p .+ rand(3) .* p

# Fit the perturbed parameters with BFGS using a conservative initial step.
res = DiffEqFlux.sciml_train(loss, p, BFGS(initial_stepnorm = 0.0001))
back to top