tidymodels / parsnip Integration

ggmlR registers a "ggml" engine for parsnip::mlp(), giving you GPU-accelerated neural networks inside the tidymodels ecosystem — resampling, tuning, workflows, and recipes all work out of the box.

library(ggmlR)
library(parsnip)

1. Classification

spec <- mlp(
  hidden_units = c(64L, 32L),
  epochs       = 20L,
  dropout      = 0.1
) |>
  set_engine("ggml") |>
  set_mode("classification")

fit_obj <- fit(spec, Species ~ ., data = iris)

# Class predictions
preds <- predict(fit_obj, new_data = iris)
head(preds)

# Probability predictions
probs <- predict(fit_obj, new_data = iris, type = "prob")
head(probs)

# Training-set (in-sample) accuracy; the resampling section gives an honest estimate
cat(sprintf("Accuracy: %.4f\n", mean(preds$.pred_class == iris$Species)))
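The class and probability predictions are normally consistent: the predicted class is the level with the highest probability. The base R sketch below makes that relationship and the accuracy computation explicit on a small hand-made probability table shaped like predict(..., type = "prob") output. The values are invented for illustration, and the argmax rule is an assumption about how the engine picks classes, not something this page documents.

```r
# Hypothetical probability table shaped like predict(..., type = "prob") output
probs <- data.frame(
  .pred_setosa     = c(0.90, 0.05, 0.10, 0.20),
  .pred_versicolor = c(0.08, 0.80, 0.30, 0.50),
  .pred_virginica  = c(0.02, 0.15, 0.60, 0.30)
)
truth <- factor(c("setosa", "versicolor", "virginica", "virginica"),
                levels = c("setosa", "versicolor", "virginica"))

# Assumed rule for .pred_class: pick the column with the largest probability
class_levels <- sub("^\\.pred_", "", names(probs))
pred_class <- factor(
  class_levels[max.col(as.matrix(probs), ties.method = "first")],
  levels = levels(truth)
)

# Confusion matrix; accuracy = correct predictions / total predictions
conf_mat <- table(truth = truth, estimate = pred_class)
accuracy <- sum(diag(conf_mat)) / sum(conf_mat)
print(conf_mat)
cat(sprintf("Accuracy: %.4f\n", accuracy))
```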

2. Regression

spec_reg <- mlp(
  hidden_units = c(64L, 32L),
  epochs       = 50L
) |>
  set_engine("ggml") |>
  set_mode("regression")

fit_reg <- fit(spec_reg, mpg ~ ., data = mtcars)

preds_reg <- predict(fit_reg, new_data = mtcars)
head(preds_reg)
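Regression predictions come back as a single .pred column. To summarize them the way the accuracy line does for classification, compare .pred with the observed outcome using RMSE and R². A base R sketch with made-up numbers (toy_preds is a hypothetical stand-in for preds_reg); yardstick's rmse() and rsq() compute the same quantities, with rsq() using the squared correlation.

```r
# Hypothetical observed outcomes and predictions shaped like predict() output
truth     <- c(10, 12, 15, 20)
toy_preds <- data.frame(.pred = c(11, 12, 14, 18))

# Root mean squared error: typical size of a prediction error, in outcome units
rmse <- sqrt(mean((truth - toy_preds$.pred)^2))

# Correlation-based R^2, the estimate yardstick::rsq() reports
rsq <- cor(truth, toy_preds$.pred)^2

cat(sprintf("RMSE: %.3f  R^2: %.3f\n", rmse, rsq))
```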

3. Engine parameters

The ggml engine maps standard parsnip arguments to ggmlR's own argument names. Arguments you leave unset take the ggmlR default:

parsnip argument   ggmlR argument   Default
hidden_units       hidden_layers    c(128, 64)
epochs             epochs           10
dropout            dropout          0.2
activation         activation       "relu"
learn_rate         learn_rate       0.001

# Customize architecture
spec_custom <- mlp(
  hidden_units = c(128L, 64L, 32L),
  epochs       = 30L,
  dropout      = 0.3,
  activation   = "relu"
) |>
  set_engine("ggml") |>
  set_mode("classification")
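One way to read the table: each parsnip argument is renamed to its ggmlR counterpart, and anything left unset falls back to the Default column. The base R sketch below models exactly that with a hypothetical helper, map_mlp_args(). It illustrates the table only and is not ggmlR's actual translation code; with ggmlR loaded, parsnip::translate(spec_custom) should print the fit call the engine really receives.

```r
# Defaults from the table above, keyed by the ggmlR argument name
ggml_defaults <- list(
  hidden_layers = c(128, 64),
  epochs        = 10,
  dropout       = 0.2,
  activation    = "relu",
  learn_rate    = 0.001
)

# parsnip name -> ggmlR name (only hidden_units is renamed)
name_map <- c(hidden_units = "hidden_layers", epochs = "epochs",
              dropout = "dropout", activation = "activation",
              learn_rate = "learn_rate")

# Hypothetical helper: rename the supplied parsnip arguments, then overlay
# them on the defaults so unset arguments keep their default values
map_mlp_args <- function(...) {
  user <- list(...)
  names(user) <- name_map[names(user)]
  modifyList(ggml_defaults, user)
}

args <- map_mlp_args(hidden_units = c(128L, 64L, 32L), epochs = 30L,
                     dropout = 0.3, activation = "relu")
str(args)   # learn_rate keeps its default of 0.001
```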

4. Resampling with rsample

library(rsample)
library(tune)
library(workflows)
library(yardstick)

set.seed(123)                       # reproducible fold assignment
folds <- vfold_cv(iris, v = 5L)

spec <- mlp(hidden_units = c(32L), epochs = 10L) |>
  set_engine("ggml") |>
  set_mode("classification")

wf <- workflow() |>
  add_model(spec) |>
  add_formula(Species ~ .)

# Fit on each analysis set, evaluate on the matching held-out fold
results <- fit_resamples(wf, resamples = folds)
collect_metrics(results)            # metrics averaged across the 5 folds
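vfold_cv() assigns every row to exactly one of v folds; each resample trains on the other v - 1 folds (the analysis set) and evaluates on the held-out fold (the assessment set). A simplified base R sketch of that assignment, ignoring rsample options such as strata and repeats:

```r
set.seed(123)
v <- 5L
n <- nrow(iris)

# Randomly assign each row to one of v equal-sized folds
fold_id <- sample(rep(seq_len(v), length.out = n))

# For each fold k: analysis = rows outside fold k, assessment = rows inside it
splits <- lapply(seq_len(v), function(k) {
  list(analysis   = which(fold_id != k),
       assessment = which(fold_id == k))
})

# 120 training rows and 30 held-out rows per resample for iris
sapply(splits, function(s) c(analysis   = length(s$analysis),
                              assessment = length(s$assessment)))
```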

5. Recipes for preprocessing

ggmlR accepts only numeric features. Use recipes to handle factors, missing values, and scaling:

library(recipes)
library(workflows)

rec <- recipe(Species ~ ., data = iris) |>
  step_normalize(all_numeric_predictors())

spec <- mlp(hidden_units = c(32L), epochs = 10L) |>
  set_engine("ggml") |>
  set_mode("classification")

wf <- workflow() |>
  add_recipe(rec) |>
  add_model(spec)

fit_obj <- fit(wf, data = iris)
predict(fit_obj, new_data = iris)
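Because the recipe is part of the workflow, fit() estimates the step_normalize() means and standard deviations from the training data, and predict() reuses those same statistics on new_data. A base R sketch of that split between estimating and applying, with a hypothetical normalize() helper and a random 100/50 split of iris:

```r
set.seed(1)
train_idx <- sample(nrow(iris), 100)
train_x   <- as.matrix(iris[train_idx, 1:4])
new_x     <- as.matrix(iris[-train_idx, 1:4])

# Estimate centering and scaling statistics on the training rows only
mu  <- colMeans(train_x)
sds <- apply(train_x, 2, sd)

# Hypothetical helper: apply the stored training statistics to any data
normalize <- function(m) sweep(sweep(m, 2, mu, "-"), 2, sds, "/")

train_n <- normalize(train_x)
new_n   <- normalize(new_x)

round(colMeans(train_n), 10)   # exactly centered by construction
round(colMeans(new_n), 3)      # close to, but not exactly, 0
```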

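For datasets with factor predictors, dummy-encode them before normalizing. iris has none (Species is the outcome), so this example converts two mtcars columns to factors:

mtcars_f <- transform(mtcars, cyl = factor(cyl), am = factor(am))

rec <- recipe(mpg ~ ., data = mtcars_f) |>
  step_impute_median(all_numeric_predictors()) |>  # fill missing numeric values
  step_dummy(all_nominal_predictors()) |>          # factors -> 0/1 indicators
  step_normalize(all_numeric_predictors())         # mean 0, SD 1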
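Conceptually, step_dummy() turns each factor into 0/1 indicator columns (by default one per non-reference level) and step_normalize() centers and scales every numeric predictor to mean 0 and standard deviation 1. The base R sketch below applies the same transformation with model.matrix() and scale() to show what the model ends up seeing. It is for intuition only: recipes names its dummy columns differently and estimates the scaling on the training data during fitting.

```r
# Copy of mtcars with two predictors stored as factors
mtcars_f <- transform(mtcars, cyl = factor(cyl), am = factor(am))

# Dummy-encode the factors (treatment contrasts: first level is the
# reference), then drop the intercept column; mpg is the outcome
x <- model.matrix(mpg ~ ., data = mtcars_f)[, -1]

# Center and scale every column to mean 0, SD 1
x_norm <- scale(x)

colnames(x_norm)            # includes cyl6, cyl8, am1
round(colMeans(x_norm), 10)
```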

6. Hyperparameter tuning

library(rsample)
library(tune)
library(dials)
library(workflows)

spec <- mlp(
  hidden_units = tune(),
  epochs       = tune(),
  dropout      = tune()
) |>
  set_engine("ggml") |>
  set_mode("classification")

wf <- workflow() |>
  add_model(spec) |>
  add_formula(Species ~ .)

# 3 levels per parameter -> 3^3 = 27 candidates.
# hidden_units() tunes one integer (a single hidden-layer width), not a vector of layers.
grid <- grid_regular(
  hidden_units(range = c(16L, 128L)),
  epochs(range = c(10L, 50L)),
  dropout(range = c(0, 0.4)),
  levels = 3L
)

set.seed(123)
folds   <- vfold_cv(iris, v = 3L)
results <- tune_grid(wf, resamples = folds, grid = grid)   # 27 x 3 = 81 fits
show_best(results, metric = "accuracy")
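grid_regular() crosses levels evenly spaced values of each tuned parameter, so three parameters at three levels give 3^3 = 27 candidates, and 3-fold resampling turns that into 27 x 3 = 81 model fits. The base R expand.grid() call below approximates the grid dials builds for these ranges, assuming plain evenly spaced values with no transformation (which we believe holds for these three parameters); it is a sketch for sizing a tuning run, not a replacement for grid_regular().

```r
# Evenly spaced values for each tuned parameter (levels = 3)
grid_sketch <- expand.grid(
  hidden_units = round(seq(16, 128, length.out = 3)),
  epochs       = round(seq(10, 50, length.out = 3)),
  dropout      = seq(0, 0.4, length.out = 3)
)

n_candidates <- nrow(grid_sketch)     # 3 * 3 * 3
n_folds      <- 3L
n_fits       <- n_candidates * n_folds

head(grid_sketch)
cat(sprintf("%d candidates x %d folds = %d model fits\n",
            n_candidates, n_folds, n_fits))
```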

7. Comparison with other engines

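library(rsample)
library(tune)
library(workflows)
library(workflowsets)

set.seed(123)
folds <- vfold_cv(iris, v = 5L)

wf_set <- workflow_set(
  preproc = list(basic = Species ~ .),
  models  = list(
    ggml = mlp(hidden_units = c(32L), epochs = 20L) |>
      set_engine("ggml") |> set_mode("classification"),
    # nnet fits a single hidden layer; its epochs argument maps to nnet's maxit
    nnet = mlp(hidden_units = 32L, epochs = 200L) |>
      set_engine("nnet") |> set_mode("classification")
  )
) |>
  workflow_map("fit_resamples", resamples = folds)

rank_results(wf_set, rank_metric = "accuracy")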
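rank_results() orders workflows by their resampled estimate of the chosen metric. That estimate is the mean over folds, reported alongside its standard error (sd / sqrt(n)) as in collect_metrics(). A base R sketch of the aggregation and ranking, using invented per-fold accuracies for two hypothetical workflows wf_A and wf_B (not real results for ggml or nnet):

```r
# Invented per-fold accuracies for two hypothetical workflows (5 folds each)
fold_acc <- data.frame(
  wflow_id = rep(c("wf_A", "wf_B"), each = 5),
  accuracy = c(0.93, 0.97, 0.90, 0.97, 0.93,
               0.97, 0.93, 0.97, 1.00, 0.93)
)

# Mean, fold count, and standard error per workflow
summ <- do.call(rbind, lapply(
  split(fold_acc$accuracy, fold_acc$wflow_id),
  function(a) data.frame(mean = mean(a), n = length(a),
                         std_err = sd(a) / sqrt(length(a)))
))
summ$wflow_id <- rownames(summ)

# Higher accuracy is better, so rank by decreasing mean
summ <- summ[order(-summ$mean), ]
summ$rank <- seq_len(nrow(summ))
summ
```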

8. Extracting the fitted engine and fit time

After fitting a workflow or parsnip model you can reach the native ggmlR object and the training time with the standard hardhat/tune extractors.

spec <- mlp(hidden_units = c(16L), epochs = 10L) |>
  set_engine("ggml") |>
  set_mode("classification")

fit_obj <- fit(spec, Species ~ ., data = iris)

# The native ggmlR engine object (class "ggmlr_parsnip_model").
# extract_fit_*() are re-exported by parsnip (originally from hardhat).
eng <- parsnip::extract_fit_engine(fit_obj)
class(eng)

# Training time parsnip recorded for the fit (one-row tibble: stage_id, elapsed).
parsnip::extract_fit_time(fit_obj)

The engine object is the same one returned by ggmlR’s own fit wrappers, so all the inspection helpers work on it directly:

ggml_model_backend(eng)            # "vulkan" or "cpu" (actual backend used)
head(ggml_training_history(eng))   # per-epoch loss / accuracy curve
generics::glance(eng)              # one-row model summary
generics::tidy(eng)                # one row per layer

9. Limitations of the engine object

The fitted engine object holds a live compiled model. Keep these limits in mind, especially for resampling and tuning on a GPU.

  • Serialization / GPU state. A compiled model carries external pointers (the ggml backend and scheduler) and, on GPU, weights living in Vulkan buffers. These do not survive a plain saveRDS() / readRDS() or being shipped to a worker process as-is. For mlr3 the learner marshals the model (marshal_model() / unmarshal_model()); within tidymodels, persist via ggmlR’s own ggml_save_model() / ggml_load_model() (sequential/functional) or ag_save_model() / ag_load_model() (autograd) rather than serializing the raw engine object.

  • Parallel resampling. fit_resamples() / tune_grid() with a parallel backend serialize fitted models across workers; because of the GPU state above, prefer sequential execution (or the CPU backend) when in doubt. In particular, avoid control_grid(parallel_over = "everything") together with a live GPU model, as it can crash or return wrong results. The safe patterns are parallel_over = "resamples" (what the default, NULL, resolves to when there is more than one resample) or running on CPU.

  • Autograd trade-off. Autograd (ag_sequential) models are rebuilt on unmarshal from a captured (dims, hyperparameters) snapshot, not from the training task. A custom model_fn that reads the task at fit time therefore cannot round-trip through marshal, because the task is NULL at rebuild time. Stick to architectures driven by dims and parameters if you need marshalable autograd models.

  • One device per fit. A compiled model is bound to the backend chosen at set_engine("ggml", backend = ...) (or auto-detected). It is not migrated between CPU and GPU after fitting; re-fit to change device.


Summary

Feature                                     Supported
Classification                              Yes (class, prob)
Regression                                  Yes (numeric)
GPU (Vulkan)                                Yes (auto-detected)
Recipes / preprocessing                     Yes
Resampling                                  Yes
Tuning                                      Yes
Workflows                                   Yes
extract_fit_engine() / extract_fit_time()   Yes