2025-03-10
Machine Learning In Elixir - Neural Network Example
Neural Network Elixir script - Nx/Axon
Section: MNIST digit classification with Nx and Axon
# Install notebook dependencies:
# - :nx       — tensor/numerical computing library
# - :exla     — XLA-based compiler backend for Nx (used as default backend below)
# - :axon     — neural-network library built on Nx
# - :kino     — Livebook interactive widgets/rendering
# - :scidata  — dataset downloads (MNIST is fetched below)
# - :table_rex — table rendering (presumably used by Axon's display output — TODO confirm)
Mix.install([
{:nx, "~> 0.9.2"},
{:exla, "~> 0.9.2"},
{:axon, "~> 0.7.0"},
{:kino, "~> 0.15.3"},
{:scidata, "~> 0.1.11"},
{:table_rex, "~> 3.1.1"}
])
# Make EXLA (XLA JIT) the default backend for all subsequent Nx tensor
# operations; returns the previous default (shown in the cell output below).
Nx.default_backend(EXLA.Backend)
{Nx.BinaryBackend, []}
# Fetch MNIST; both elements are {binary_data, type, shape} triples.
{images, labels} = Scidata.MNIST.download()
{image_data, image_type, image_shape} = images
{label_data, label_type, label_shape} = labels

# Decode the raw pixel bytes, scale each pixel from 0..255 into [0, 1]
# (normalization), and flatten every image into one 784-element row.
pixel_tensor = Nx.from_binary(image_data, image_type)
images = Nx.reshape(Nx.divide(pixel_tensor, 255), {60_000, :auto})

# Decode the digit labels and one-hot encode them: comparing each label
# against iota 0..9 yields a length-10 row with a single 1 at the digit.
digit_tensor = Nx.reshape(Nx.from_binary(label_data, label_type), label_shape)
labels = Nx.equal(Nx.new_axis(digit_tensor, -1), Nx.iota({1, 10}))
# Split the 60,000 examples into a training set and a held-out test set.
# The split point is a single named constant instead of being hard-coded
# twice (0..49_999 and 50_000..), so it can be changed in one place.
split_at = 50_000
train_range = 0..(split_at - 1)//1
test_range = split_at..-1//1

train_images = images[train_range]
train_labels = labels[train_range]
test_images = images[test_range]
test_labels = labels[test_range]
batch_size = 64

# Both datasets are batched the same way: split features and targets into
# batch_size-sized tensors and zip them into a lazy stream of
# {image_batch, label_batch} pairs.
batchify = fn features, targets ->
  Stream.zip(
    Nx.to_batched(features, batch_size),
    Nx.to_batched(targets, batch_size)
  )
end

train_data = batchify.(train_images, train_labels)
test_data = batchify.(test_images, test_labels)
# A simple MLP classifier: 784 input pixels -> 128 hidden units (ReLU)
# -> 10 output units (softmax over digit classes).
input = Axon.input("images", shape: {nil, 784})
hidden = Axon.dense(input, 128, activation: :relu)
model = Axon.dense(hidden, 10, activation: :softmax)

# Render the model as a graph, tracing shapes with a 1x784 f32 template.
template = Nx.template({1, 784}, :f32)
Axon.Display.as_graph(model, template)
graph TD;
22[/"images (:input) #Nx.Tensor<
f32[1][784]
Nx.TemplateBackend
>"/];
23["dense_0 (:dense) #Nx.Tensor<
f32[1][128]
Nx.TemplateBackend
>"];
24["relu_0 (:relu) #Nx.Tensor<
f32[1][128]
Nx.TemplateBackend
>"];
25["dense_1 (:dense) #Nx.Tensor<
f32[1][10]
Nx.TemplateBackend
>"];
26["softmax_0 (:softmax) #Nx.Tensor<
f32[1][10]
Nx.TemplateBackend
>"];
25 --> 26;
24 --> 25;
23 --> 24;
22 --> 23;
# Train with SGD minimizing categorical cross-entropy, tracking accuracy,
# for 10 epochs, JIT-compiled via EXLA.
#
# Fix: the run previously passed a bare %{} as the initial state, which
# Axon 0.7 logs as deprecated ("use %Axon.ModelState{} instead" — see the
# warning captured below). Axon.ModelState.empty() is the supported way
# to start from freshly initialized parameters.
trained_model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 10, compiler: EXLA)
22:15:31.136 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
22:15:31.137 [warning] passing parameter map to initialization is deprecated, use %Axon.ModelState{} instead
Epoch: 0, Batch: 750, accuracy: 0.7748213 loss: 0.9494509
Epoch: 1, Batch: 750, accuracy: 0.8812001 loss: 0.6951301
Epoch: 2, Batch: 750, accuracy: 0.8963672 loss: 0.5867931
Epoch: 3, Batch: 750, accuracy: 0.9058963 loss: 0.5236613
Epoch: 4, Batch: 750, accuracy: 0.9121794 loss: 0.4809286
Epoch: 5, Batch: 750, accuracy: 0.9173602 loss: 0.4493324
Epoch: 6, Batch: 750, accuracy: 0.9221663 loss: 0.4245615
Epoch: 7, Batch: 750, accuracy: 0.9260361 loss: 0.4043231
Epoch: 8, Batch: 750, accuracy: 0.9293233 loss: 0.3872811
Epoch: 9, Batch: 750, accuracy: 0.9321946 loss: 0.3725979
#Axon.ModelState<
Parameters: 101770 (407.08 KB)
Trainable Parameters: 101770 (407.08 KB)
Trainable State: 0, (0 B)
>
# Evaluate the trained parameters on the held-out test data,
# reporting accuracy per batch.
eval_loop = Axon.Loop.evaluator(model)
eval_loop = Axon.Loop.metric(eval_loop, :accuracy)
Axon.Loop.run(eval_loop, test_data, trained_model_state, compiler: EXLA)
22:15:36.522 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Batch: 156, accuracy: 0.9394904
%{
0 => %{
"accuracy" => #Nx.Tensor<
f32
EXLA.Backend<host:0, 0.1385899118.1730805790.108597>
0.9394904375076294
>
}
}
# Take the first image of the first test batch and render it as a
# 28x28 heatmap (labels in the pair are ignored here).
{test_batch, _test_label_batch} = Enum.at(test_data, 0)
test_image = test_batch[0]

Nx.to_heatmap(Nx.reshape(test_image, {28, 28}))
#Nx.Heatmap<
f32[28][28]
>
# Compile the model into an {init_fn, predict_fn} pair; only the
# prediction function is needed here.
{_init_fn, predict_fn} = Axon.build(model, compiler: EXLA)

# Add a leading batch axis to the single image and run it through the
# trained model to get the 10 class probabilities.
batched_image = Nx.new_axis(test_image, 0)
probabilities = predict_fn.(trained_model_state, batched_image)

# The index of the largest probability is the predicted digit.
Nx.argmax(probabilities)
#Nx.Tensor<
s32
EXLA.Backend<host:0, 0.1385899118.1730805790.108626>
3
>