Baduit

A young French developper who really likes (modern) C++

About me
17 January 2021

Chainer les comparaisons en C++

by Baduit

Le problème initial

On a tous déjà voulu écrire un bout de code qui ressemblait à ça :

1
2
3
4
5
6
7
8
9
10
// Cas 1
if (6 > x > 2)
{
    // Do stuff
}
// Cas 2
if (x == y == z)
{
    // Some code
}

Mais, dans les deux cas, ça n’a pas eu le comportement escompté.

En effet, pour le premier, au lieu de tester si x est compris entre 2 et 6, la condition vérifie si 6 est supérieur à x puis regarder si le résultat de cette comparaison (true, qui vaut 1, ou false, qui vaut 0) est supérieur à 2, ce qui est forcément faux.

Quant au second cas, on compare x avec y, puis on compare ensuite le résultat de cette comparaison avec z au lieu de vérifier que les trois valeurs sont identiques.

Ce n’est pas seulement le cas du C++, à ma connaissance, seul Python permet chainer les comparaisons de manière native.

Démonstration

Objectif

Le but de cette démonstation sera de réussir à chainer les comparaisons de la structure suivante :

1
2
3
4
5
struct Point
{
    int x = 0;
    int y = 0;
};

Et ainsi pouvoir écrire le code suivant :

1
2
3
4
5
6
7
8
9
10
11
12
int main()
{
    Point a = {4, 3};
    Point b = {4, 3};
    Point c = {1, 8};
    Point d = {9, 8};

    if (a == b == c == d)
        ; // do something
    else
        ; // do something else
}

Bien sûr, on ne s’arrêtera pas à un maximum de 4 points à comparer, techniquement on pourra enchainer autant de comparaisons que souhaité.

La surcharge d’opérateur

En C++, il est possible de surcharger la plupart des opérateurs, dont les opérateurs de comparaison (==, !=, >, >=, <, <=). Voici un exemple :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// The overload of the operator== to be able to compare 2 points
bool operator==(const Point& a, const Point& b)
{
    return
        a.x == b.x &&
        a.y == b.y;
}

// Now we can use it normally
int main()
{
    Point a = {4, 3};
    Point b = {4, 3};
    if (a == b)
        ; // do something
}

En général, on fait retourner aux opérateurs de comparaison des booléens pour que leurs utilisations collent avec leur sémantique, mais en réalité, on peut leur faire retourner n’importe quel type :

1
2
3
4
std::string operator==(const Point& a, const Point& b)
{
    return "Don't do this, return a string is not a good idea";
}

On peut aussi comparer des objets de types différents :

1
2
3
4
5
6
bool operator==(const Point& a, const std::pair<int, int>& b)
{
    return
        a.x == b.first &&
        a.y == b.second;
}

L’idée

Lorsque l’on a effectué une comparaison entre 2 valeurs, on a 3 informations :

Pour pouvoir chainer les comparaisons, on a besoin, en plus du résultat, de la valeur de droite. Pourquoi celle de droite ? Car on effectue les comparaisons de droite à gauche, donc si on a l’expression a > b > c, on peut la traduire en (a > b) && (b > c). Ce qui nous donnera 2 sous-expressions a > b et b > c. Dans la première sous-expression, ‘a’ est la valeur de gauche et ‘b’ est la valeur de droite. Dans la 2nde sous-expression, c’est ‘b’ qui est réutilisé et non pas ‘a’.

Par conséquent, quand on va surcharger notre opérateur, on va retourner un objet contenant les 2 informations que l’on veut.

La classe résultat

On va créer une classe ComparisonResult pour stocker les informations nécessaires :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class ComparisonResult
{
    public:
        ComparisonResult(bool previous_result, const Point& right_value):
            _result(previous_result),
            _value(right_value)
        {}

    // Explicit conversion to bool
    explicit operator bool () const { return _result; }

    const Point& get_value() const { return _value; }

    private:
        bool _result;
        const Point& _value;
};

On notera que j’ai aussi rajouté l’opérateur pour pouvoir convertir notre résultat en booléen (operator bool). Si on ne l’avait pas ajouté, au aurait été obligé d’écrire un code ressemblant à ça :

1
2
if ((a == b).get_result())
    ;

Je pense qu’on est tous d’accord pour dire que ce n’est pas très lisible.

Comparer 2 points

Maintenant que notre classe pour contenir notre résultat est créée, on peut surcharger notre opérateur== entre 2 points :

1
2
3
4
ComparisonResult operator==(const Point& left, const Point& right)
{
    return ComparisonResult(left == right, right); 
}

Comparer un résultat

Il nous manque juste une dernière chose pour pouvoir atteindre notre but : pouvoir comparer notre résultat avec un autre point. On va donc surcharger l’opérateur== entre un résultat et un point :

1
2
3
4
5
6
7
8
9
10
11
ComparisonResult operator==(const ComparisonResult& result, const Point& point)
{
    if (result)
    {
        return result.get_value() == point;
    }
    else
    {
        return ComparisonResult(false, point);
    }
}

On remarquera que l’on teste le résultat de la comparaison précédente avant d’effectuer la comparaison entre les points, car si le résultat précédent était faux, le résultat final doit être faux, peu importe le résultat des comparaisons suivantes.

Par exemple, dans l’expression suivante x == y && y == z si x == y est faux, alors y == z n’est même pas évalué. Les comparaisons chainées de nos points fonctionnement exactement de la même manière.

Code final

Maintenant que l’on a tout ce qu’il faut pour que ça fonctionne, voici le code final :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <iostream>

struct Point
{
    int x = 0;
    int y = 0;
};

class ComparisonResult
{
    public:
        ComparisonResult(bool previous_result, const Point& right_value):
            _result(previous_result),
            _value(right_value)
        {}

    // Explicit conversion to bool
    explicit operator bool () const { return _result; }

    const Point& get_value() const { return _value; }

    private:
        bool _result;
        const Point& _value;
};

ComparisonResult operator==(const Point& left, const Point& right)
{
    bool result =
        left.x == right.x &&
        left.y == right.y;
    return ComparisonResult(result, right); 
}

ComparisonResult operator==(const ComparisonResult& result, const Point& point)
{
    if (result)
    {
        return result.get_value() == point;
    }
    else
    {
        return ComparisonResult(false, point);
    }
}

// Now an example of use
int main()
{
    {
        Point a = {4, 3};
        Point b = {4, 3};
        Point c = {4, 3};
        Point d = {4, 3};

        if (a == b == c == d)
            std::cout << "1: They are all equal" << std::endl;
        else
            std::cout << "1 :At least one of them is different" << std::endl;
    }

    {
        Point a = {4, 3};
        Point b = {4, 8};
        Point c = {4, 3};
        Point d = {4, 3};

        if (a == b == c == d)
            std::cout << "2: They are all equal" << std::endl;
        else
            std::cout << "2: At least one of them is different" << std::endl;
    }
}

Vous pouvez tester ce code en ligne ici.

Aller plus loin

On peut surcharger les autres opérateurs de la même manière, et ainsi pouvoir chainer n’importe quelles comparaisons.

On pourrait aussi imaginer pouvoir faire quelque chose de similaire dans d’autres langages supportant la surcharge d’opérateurs.

Exemple d’implémentation

J’ai fait une bibliothèque C++ qui réutilise le principe et permet de chainer les comparaisons facilement pour n’importe quel type, et impose aussi quelques restrictions pour de pas pouvoir écrire des horreurs de ce style-là : x > y != 3 == 7 (qui n’auraient aucun sens).

Elle contient en plus quelques alias et littéraux pour les types de bases (les nombres et strings): https://github.com/Baduit/Croissant

tags: C++ - Modern C++