"""
    ComputableNode(op, args)

Graph node recording that the operation `op` is to be applied to the arguments
stored in `args`. `forward` evaluates it by first evaluating every element of
`args` and then passing the results, together with `op`, to `forward` again.
"""
struct ComputableNode{Op, Args} <: Node
    op::Op
    args::Args
end
"""
    CachedNode(node, out)

Pairs a graph `node` with the value `out` last computed for it. The struct is
mutable because `forward` overwrites `out` each time the node is re-evaluated;
the new value must be convertible to the type `out` had at construction.
"""
mutable struct CachedNode{N, Out} <: Node
    node::N
    out::Out
end
"""
    register(op, args...)

Record the application of `op` to `args` as a new graph node. A
`ComputableNode` is built, evaluated immediately with `forward`, and returned
wrapped together with its value in a `CachedNode`.

The method was previously defined twice with identical bodies; the second
definition silently overwrote the first, and method overwriting is rejected
during module precompilation, so only one definition is kept.
"""
function register(op, args...)
    node = ComputableNode(op, args)
    out = forward(node)
    return CachedNode(node, out)
end
import Base: +, -, *, /
# Extend the binary arithmetic operators imported from Base to pairs of graph
# nodes. Each method hands the operator itself and both operands to `register`,
# which returns a new graph node.
for op in (:+, :-, :*, :/)
    @eval $op(x::Node, y::Node) = register($op, x, y)
end
import Base: sin
# Extend Base.sin to graph nodes: the call is recorded through `register`
# rather than applied directly.
function sin(x::Node)
    return register(sin, x)
end
# `forward` evaluates the computation graph.

# A cached node is re-evaluated; the fresh value overwrites the `out` field and
# is also returned.
forward(cached::CachedNode) = cached.out = forward(cached.node)

# A computable node first evaluates each of its arguments, then applies its
# operation to the resulting values.
forward(node::ComputableNode) = forward(node.op, map(forward, node.args)...)

# `register` stores plain Julia functions (`+`, `*`, `sin`, ...) as a node's
# operation, so they must be callable here. Without this method, evaluating
# `x * x` fails with
# `MethodError: no method matching forward(::typeof(*), ::Float64, ::Float64)`.
forward(f::Function, args...) = f(args...)

# Operator wrappers apply the function stored in their `f` field. `kwargs...`
# must be declared in the signature; otherwise `kwargs` in the body is an
# undefined variable and the call throws `UndefVarError`.
forward(op::Operator, args...; kwargs...) = op.f(args...; kwargs...)

# A variable evaluates to its stored value; graph evaluation stops here.
forward(var::Variable) = var.value
# Example function sin(x²). The square is written as `x * x` (not `x^2`) so that
# only `*` and `sin` need methods for graph nodes.
function f(x)
    squared = x * x
    return sin(squared)
end
# Evaluate f at 5.0 through the graph. The second constructor argument is
# presumably the initial gradient — confirm against the Variable definition.
x = Variable(5.0, 0.0)
# This call is the one that raised the MethodError reported below. The former
# trailing marker "<-error" was itself invalid code: Julia parses
# `f(x) <-error` as the comparison `f(x) < -error`.
y = f(x)
Próbuję zrozumieć ten algorytm (automatyczne różniczkowanie z wpisu podlinkowanego poniżej), ale wywołanie `y = f(x)` kończy się poniższym błędem `MethodError` i nie potrafię go rozwiązać.
MethodError: no method matching forward(::typeof(*), ::Float64, ::Float64)
Closest candidates are:
forward(!Matched::Operator, ::Any...) at In[82]:3
Stacktrace:
[1] forward(::ComputableNode{typeof(*),Tuple{Variable{Float64},Variable{Float64}}}) at .\In[82]:2
[2] register(::Function, ::Variable{Float64}, ::Vararg{Variable{Float64},N} where N) at .\In[81]:3
[3] *(::Variable{Float64}, ::Variable{Float64}) at .\In[81]:9
[4] f(::Variable{Float64}) at .\In[192]:1
[5] top-level scope at In[192]:3
LINK: http://blog.rogerluo.me/2018/10/23/write-an-ad-in-one-day/