
The permutation filter randomly permutes the values of a single feature in an mlr3::Task to break its association with the response. The permuted feature, together with the unmodified features, is used to perform an mlr3::resample(). The permutation filter score is the difference between the performance aggregated over this resampling, as measured by the mlr3::Measure, and the performance estimated on the unmodified mlr3::Task.
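The procedure can be sketched in base R without mlr3. This is a minimal sketch under stated assumptions, not the mlr3filters implementation: `lm()` stands in for the mlr3::Learner, a fixed holdout split of `mtcars` stands in for the mlr3::Resampling, mean squared error (lower is better) stands in for the mlr3::Measure, and `permutation_scores()` is a hypothetical helper. Each feature is permuted across all rows, the model is refit on the permuted data, and the score is the permuted MSE minus the baseline MSE.

```r
# Minimal base-R sketch of permutation feature importance (hypothetical helper,
# not the mlr3filters code). Assumptions: lm() plays the learner, a fixed
# holdout split plays the resampling, and mean squared error plays the measure.
permutation_scores = function(data, target, features, train_idx, seed = 1) {
  test_idx = setdiff(seq_len(nrow(data)), train_idx)
  fml = reformulate(features, response = target)

  # Fit on the training rows and return the MSE on the held-out rows.
  holdout_mse = function(d) {
    fit = lm(fml, data = d[train_idx, ])
    pred = predict(fit, newdata = d[test_idx, ])
    mean((d[[target]][test_idx] - pred)^2)
  }

  # Baseline performance on the unmodified data.
  baseline = holdout_mse(data)

  set.seed(seed)
  vapply(features, function(f) {
    permuted = data
    # Shuffle one feature across all rows to break its link to the target,
    # then refit and re-evaluate with the same holdout split.
    permuted[[f]] = sample(permuted[[f]])
    # Positive score: MSE went up, i.e. the model got worse without the feature.
    holdout_mse(permuted) - baseline
  }, numeric(1))
}

train_idx = seq(1, nrow(mtcars), by = 2)
scores = permutation_scores(mtcars, "mpg", c("wt", "hp", "qsec"), train_idx)
print(round(scores, 3))
```

A positive score means the model got worse once the feature was shuffled; scores near zero, or slightly negative by chance, suggest the feature carries little information for this learner.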

Parameters

standardize

logical(1)
Standardize the feature importance by dividing all scores by the maximum score.

nmc

integer(1)

Number of Monte-Carlo iterations to use in computing the feature importance.
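The two parameters can be illustrated with a small sketch, assuming the per-iteration scores are averaged over the `nmc` Monte-Carlo iterations before the optional standardization. `aggregate_permutation_scores()` is a hypothetical helper for illustration, not part of mlr3filters.

```r
# Hypothetical helper (not part of mlr3filters): combine per-iteration
# permutation scores into one score per feature.
# `delta` is an nmc x p matrix: row i holds the scores from Monte-Carlo
# iteration i, one column per feature.
aggregate_permutation_scores = function(delta, standardize = FALSE) {
  # Assumption: average over the nmc Monte-Carlo iterations.
  scores = colMeans(delta)
  if (standardize) {
    # Divide by the maximum so the top-ranked feature gets a score of 1.
    scores = scores / max(scores)
  }
  scores
}

# nmc = 2 iterations, three features a, b, c
delta = matrix(
  c(0.10, 0.04, 0.00,
    0.14, 0.02, 0.02),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("a", "b", "c"))
)
aggregate_permutation_scores(delta)                      # means: a = 0.12, b = 0.03, c = 0.01
aggregate_permutation_scores(delta, standardize = TRUE)  # a = 1, b = 0.25, c = 1/12
```

Standardizing by the maximum only rescales the scores and leaves the ranking unchanged, provided the largest score is positive.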

Super classes

mlr3filters::Filter -> mlr3filters::FilterLearner -> FilterPermutation

Public fields

learner

(mlr3::Learner)

resampling

(mlr3::Resampling)

measure

(mlr3::Measure)

Active bindings

hash

(character(1))
Hash (unique identifier) for this object.

phash

(character(1))
Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values) or feature selection (feature names).

Methods



Method new()

Create a FilterPermutation object.

Usage

FilterPermutation$new(
  learner = mlr3::lrn("classif.featureless"),
  resampling = mlr3::rsmp("holdout"),
  measure = NULL
)

Arguments

learner

(mlr3::Learner)
mlr3::Learner to use for model fitting.

resampling

(mlr3::Resampling)
mlr3::Resampling strategy used to estimate the performance on the unmodified and the permuted task.

measure

(mlr3::Measure)
mlr3::Measure used to evaluate the performance.
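Because some measures are maximized (for example `classif.acc`) and others are minimized (for example `classif.ce`), the sign of a raw performance difference depends on the chosen measure. The sketch below shows one way to orient the difference so that a larger value always means a more important feature. `importance_from_perf()` is a hypothetical helper, and its `minimize` argument mirrors the logical `minimize` field of an mlr3::Measure; it is not taken from the mlr3filters source.

```r
# Hypothetical helper (not the mlr3filters implementation): turn a raw
# performance difference into an importance score where larger means more
# important, whether the measure is minimized or maximized.
# `minimize` mirrors the logical `minimize` field of an mlr3::Measure.
importance_from_perf = function(perf_permuted, perf_baseline, minimize) {
  diff = perf_permuted - perf_baseline
  # For a loss such as classification error, a worse model has a higher value,
  # so the raw difference already points the right way. For a score such as
  # accuracy, a worse model has a lower value, so flip the sign.
  if (minimize) diff else -diff
}

# Accuracy drops from 0.96 to 0.70 after permuting a feature:
importance_from_perf(0.70, 0.96, minimize = FALSE)  # 0.26
# The matching classification error rises from 0.04 to 0.30:
importance_from_perf(0.30, 0.04, minimize = TRUE)   # 0.26
```

Both calls describe the same loss of predictive power, so orienting by `minimize` gives the same importance regardless of which of the two measures is passed.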


Method clone()

The objects of this class are cloneable with this method.

Usage

FilterPermutation$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

if (requireNamespace("rpart")) {
  learner = mlr3::lrn("classif.rpart")
  resampling = mlr3::rsmp("holdout")
  measure = mlr3::msr("classif.acc")
  filter = flt("permutation", learner = learner, measure = measure, resampling = resampling,
    nmc = 2)
  task = mlr3::tsk("iris")
  filter$calculate(task)
  as.data.table(filter)
}
#>         feature score
#>          <char> <num>
#> 1: Sepal.Length   1.0
#> 2:  Sepal.Width   0.4
#> 3: Petal.Length   0.2
#> 4:  Petal.Width   0.0

if (mlr3misc::require_namespaces(c("mlr3pipelines", "rpart"), quietly = TRUE)) {
  library("mlr3pipelines")
  task = mlr3::tsk("iris")

  # Note: the value of `filter.frac` is chosen arbitrarily here and should be tuned.

  graph = po("filter", filter = flt("permutation", nmc = 2), filter.frac = 0.5) %>>%
    po("learner", mlr3::lrn("classif.rpart"))

  graph$train(task)
}
#> $classif.rpart.output
#> NULL
#>