Skip to content

Classifiers

Bases: ClassifierMixin, NeuralNet

Implementation of the NeuralNet base class. See sklx.net.NeuralNet for detailed documentation.

Source code in sklx/classifier.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class NeuralNetworkClassifier(ClassifierMixin, NeuralNet):
    """
    Implementation of the NeuralNet base class. See `sklx.net.NeuralNet` for detailed documentation.

    Args:
        module: Network module, forwarded to `NeuralNet` as `module`.
        max_epochs: Forwarded to `NeuralNet` as `max_epochs`.
        lr: Forwarded to `NeuralNet` as `lr`.
        criterion: Forwarded to `NeuralNet` as `criterion`.
        classes: Optional class labels; exposed as an array via `classes_`.
    """

    # Class-level attribute values. NOTE(review): presumably fallbacks read by
    # `NeuralNet`; `batch_size` and `optimizer` are not forwarded by
    # `__init__` below -- confirm against the base class.
    module = None
    max_epochs = 10
    lr = 0.1
    batch_size = 10
    optimizer = optimizers.SGD

    def __init__(
        self,
        module: nn.Module,
        max_epochs: int,
        lr: float,
        criterion: Callable,
        classes=None,
    ) -> None:
        super().__init__(
            module=module, criterion=criterion, lr=lr, max_epochs=max_epochs
        )
        self.classes = classes

    @property
    def classes_(self):
        """
        Class labels as a NumPy array.

        Raises:
            AttributeError: If `classes` is `None`. Raising (rather than
                returning `np.array(None)`, a 0-d object array) makes
                `hasattr(estimator, "classes_")` report False when no labels
                are available.
        """
        if self.classes is None:
            raise AttributeError(
                f"{type(self).__name__} has no attribute 'classes_': "
                "no `classes` were provided"
            )
        return np.array(self.classes)