
The permutation filter randomly permutes the values of a single feature in an mlr3::Task to break its association with the response. The task containing the permuted feature, together with the unmodified features, is then evaluated with mlr3::resample(). The permutation filter score is the difference between the aggregated performance of the mlr3::Measure on this permuted task and the performance estimated on the unmodified mlr3::Task.
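To make the procedure concrete, here is a minimal base-R sketch of the same idea. It is not the mlr3filters implementation: it swaps in `lm()` on `mtcars`, a single fixed holdout split, and RMSE as the measure, and the helper functions `holdout_rmse()` and `permutation_scores()` are hypothetical. Following the description above, each score is the performance on the permuted data minus the baseline performance, and the model is refit on the permuted data.

```r
# Hypothetical base-R sketch of permutation feature importance.
# Not the mlr3filters source: lm(), RMSE and a fixed holdout split are
# stand-ins for the Learner, Measure and Resampling.
set.seed(1)

# Fit lm() on the training rows and return the RMSE on the held-out rows
# (lower is better).
holdout_rmse <- function(data, target, train_idx) {
  fml <- reformulate(setdiff(names(data), target), response = target)
  fit <- lm(fml, data = data[train_idx, ])
  test <- data[-train_idx, ]
  sqrt(mean((test[[target]] - predict(fit, newdata = test))^2))
}

# For each feature: shuffle it across all rows (breaking its link with
# the response), refit and re-evaluate, then subtract the baseline.
permutation_scores <- function(data, target, train_idx) {
  baseline <- holdout_rmse(data, target, train_idx)
  features <- setdiff(names(data), target)
  vapply(features, function(f) {
    permuted <- data
    permuted[[f]] <- sample(permuted[[f]])
    holdout_rmse(permuted, target, train_idx) - baseline
  }, numeric(1))
}

data <- mtcars[, c("mpg", "wt", "hp", "qsec")]
train_idx <- sample(nrow(data), size = round(2 / 3 * nrow(data)))
scores <- permutation_scores(data, "mpg", train_idx)
print(sort(scores, decreasing = TRUE))
```

Because RMSE is minimized, a positive score in this sketch means that permuting the feature increased the error. For a measure that is maximized, such as accuracy, an informative feature would get a negative score under the same "permuted minus baseline" convention.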

Parameters

standardize

logical(1)
Standardize the feature importance scores by the maximum score.

nmc

integer(1)

Number of Monte Carlo iterations (repeated random permutations) used to compute the feature importance.
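The effect of `nmc` and `standardize` can be sketched on simulated scores. This sketch assumes that the per-iteration scores are averaged over the `nmc` iterations and that standardizing means dividing by the largest score; `one_run()` is a hypothetical stand-in for a single permutation round and is not part of mlr3filters.

```r
# Hypothetical sketch of how `nmc` and `standardize` could act on raw
# permutation scores; not the mlr3filters source code.
set.seed(42)

# Stand-in for one permutation round: one score per feature, simulated
# as noise around fixed "true" importances.
one_run <- function() {
  c(a = 0.30, b = 0.10, c = 0.05) + rnorm(3, sd = 0.02)
}

nmc <- 5
runs <- replicate(nmc, one_run())  # 3 x nmc matrix, one column per round
scores <- rowMeans(runs)           # average over the Monte Carlo rounds

standardize <- TRUE
if (standardize) {
  scores <- scores / max(scores)   # the top-scoring feature becomes 1
}
print(scores)
```

Dividing by the maximum only preserves the ranking when that maximum is positive; if every score were negative, the division would flip their order.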

Super class

mlr3filters::Filter -> FilterPermutation

Public fields

learner

(mlr3::Learner)

resampling

(mlr3::Resampling)

measure

(mlr3::Measure)

Methods


Method new()

Create a FilterPermutation object.

Usage

FilterPermutation$new(
  learner = mlr3::lrn("classif.featureless"),
  resampling = mlr3::rsmp("holdout"),
  measure = NULL
)

Arguments

learner

(mlr3::Learner)
mlr3::Learner to use for model fitting.

resampling

(mlr3::Resampling)
mlr3::Resampling strategy passed to mlr3::resample().

measure

(mlr3::Measure)
mlr3::Measure used to evaluate the performance.


Method clone()

The objects of this class are cloneable with this method.

Usage

FilterPermutation$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

if (requireNamespace("rpart")) {
  learner = mlr3::lrn("classif.rpart")
  resampling = mlr3::rsmp("holdout")
  measure = mlr3::msr("classif.acc")
  filter = flt("permutation", learner = learner, measure = measure, resampling = resampling,
    nmc = 2)
  task = mlr3::tsk("iris")
  filter$calculate(task)
  as.data.table(filter)
}
#>         feature score
#> 1:  Sepal.Width     1
#> 2:  Petal.Width     0
#> 3: Sepal.Length    -3
#> 4: Petal.Length    -4

if (mlr3misc::require_namespaces(c("mlr3pipelines", "rpart"), quietly = TRUE)) {
  library("mlr3pipelines")
  task = mlr3::tsk("iris")

  # Note: `filter.frac` is chosen arbitrarily here and should be tuned.

  graph = po("filter", filter = flt("permutation", nmc = 2), filter.frac = 0.5) %>>%
    po("learner", mlr3::lrn("classif.rpart"))

  graph$train(task)
}
#> $classif.rpart.output
#> NULL
#>