このコードにはバグがある

今日はこのようなツイートを見た。


#include <stdio.h>

class base_class {
public:
    base_class() { x = 0; y = 0; }

public:
    int x;
    int y;
};

class derived_class : public base_class {
public:
    derived_class() { z = 0; }

public:
    int z;
};

void calc_class_members(base_class *b, int array_size)
{
    for(int i = 0; i < array_size; i++)
    {
        b[i].x = i;
        b[i].y = i;
    }
}

// Calling from external
int update_class_members(
    derived_class * class_array,    // Valid array
    int num_array )
{
    if( class_array == NULL )
        return -1;    // means error

    calc_class_members(class_array, num_array);

    return 0;// success
}

このコードの誤りは base_class と derived_class の大きさの違いによるものであるが、取り敢えず実行してみれば明らかである。

int main()
{
    constexpr int num_array = 5;
    derived_class class_array[num_array];

    update_class_members(class_array, num_array);
    
    for(auto&& a : class_array)
        printf("%d, %d, %d\n", a.x, a.y, a.z);
    
    puts("---");
    
    printf("%p, %p, %p\n", &class_array[0]
                         , &class_array[1]
                         , &class_array[2]);
    
    printf("%p, %p, %p\n", &static_cast<base_class*>(class_array)[0]
                         , &static_cast<base_class*>(class_array)[1]
                         , &static_cast<base_class*>(class_array)[2]);
}

実行結果
http://melpon.org/wandbox/permlink/mVYWFiM7eEieWFeE

0, 0, 1
1, 2, 2
3, 3, 4
4, 0, 0
0, 0, 0
---
0x7ffd2e5e31e0, 0x7ffd2e5e31ec, 0x7ffd2e5e31f8
0x7ffd2e5e31e0, 0x7ffd2e5e31e8, 0x7ffd2e5e31f0

この通り、derived_class は base_class よりも sizeof(int) だけ大きいため、
配列の先頭要素へのポインタ class_array を base_class* へとキャストした後に添字演算を行うとずれた位置にアクセスしてしまう。

問題の答えたるバグはおそらくこれのことで間違いないだろう。

ではこのプログラムが正しい動作をするように書き直したい。

まず思いついたのは calc_class_members を template化 することであるが、
それだけでは元々のプログラムの意図が損なわれてしまう。

そのため、template の型をチェックして、base_class の派生クラスでないものは弾いてしまおう。


だがそれ以前に、このコードは生ポインタを使っていたりととてもモダンなC++とは呼べない。なので、ついでに適当に手直しをしておく。

//#include <stdio.h> // そもそもいらない
#include <array>
#include <vector>

struct base_class {
    int x = 0;
    int y = 0;
};

struct derived_class : base_class {
    int z = 0;
};

template <typename type, typename allocator>
void calc_class_members(std::vector<type, allocator>& b)
{
    static_assert(std::is_base_of<base_class, type>::value, "");

    int i = 0;
    for(type& a : b)
    {
        a.x = i;
        a.y = i;

        i++;
    }
}

template <typename allocator>
void update_class_members(std::vector<derived_class, allocator>& class_array)
{
    // nullptr チェックの必要はない
    
    calc_class_members(class_array);
}


// あるいはこうだろうか
template <typename type, size_t n>
void calc_class_members(std::array<type, n>& b)
{
    static_assert(std::is_base_of<base_class, type>::value, "");

    int i = 0;
    for(type& a : b)
    {
        a.x = i;
        a.y = i;

        i++;
    }
}

template <size_t n>
void update_class_members(std::array<derived_class, n>& class_array)
{
    calc_class_members(class_array);
}
#include <iostream>

int main()
{
    auto class_array1 = std::vector<derived_class>{ 5 };
    auto class_array2 = std::array<derived_class, 5>{};

    update_class_members(class_array1);
    update_class_members(class_array2);

    std::cout << "class_array1" << std::endl;
    for(auto&& a : class_array1)
        std::cout << a.x << ", " << a.y << ", " << a.z << std::endl;

    std::cout << "class_array2" << std::endl;
    for(auto&& a : class_array2)
        std::cout << a.x << ", " << a.y << ", " << a.z << std::endl;
}

実行結果
http://melpon.org/wandbox/permlink/COZ8R4QW2OXFM6u8

class_array1
0, 0, 0
1, 1, 0
2, 2, 0
3, 3, 0
4, 4, 0
class_array2
0, 0, 0
1, 1, 0
2, 2, 0
3, 3, 0
4, 4, 0

お見かけした他の方のコード


なるほど、
こんな手は思いつきもしなかった。