[Python] matplotlib の Surface (3D) プロット 座標について
2018-04-23

3D の描画でハマったのでメモ。

プロット

線(Line plot) と 分布図(Scatter plot) は ほぼ 2D と同じなので 省略。 見たい人は折りたたみを開いてみてね。

Detail

分布図と線は 2D の プロット と同じく1次元の配列を期待します。

In [1]:
%autosave 0
%matplotlib inline
Autosave disabled
In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
In [3]:
x = [7, 2, 5, 1, 7]
y = [9, 4, 6, 5, 3]
z = [3, 7, 1, 3, 5]
In [4]:
fig = plt.figure()
ax = Axes3D(fig)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.plot(x, y, z)
Out[4]:
[<mpl_toolkits.mplot3d.art3d.Line3D at 0x7fc6b7643f60>]
In [5]:
fig = plt.figure()
ax = Axes3D(fig)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.plot(x, y, z, 'ro')
Out[5]:
[<mpl_toolkits.mplot3d.art3d.Line3D at 0x7fc6b75f0438>]
In [6]:
# それぞれの点がどの座標にプロットされているかを表示してみる
pd.DataFrame([
    [xc, yc, zc]
    for r, (xc, yc, zc) in enumerate(zip(x, y, z), 1)
], columns=['x', 'y', 'z'])
Out[6]:
x y z
0 7 9 3
1 2 4 7
2 5 6 1
3 1 5 3
4 7 3 5

簡単ですね。

今回メインで説明するのは plot_surface() 関数です。

備考

plot_wireframe() は ほぼ同じです。

wireframe は 線だけで 物体を形取り surface はそれらの間に着色することで物体の形をよりリアルに表現できます。

さて、こいつが期待するのは 通常のプロットとは違い、2次元のネストしたデータです。

X, Y, Z Data values as 2D arrays

座標についてはこれだけの説明と申し訳程度の描画画像が 公式チュートリアル にあるわけですが、ちょっと何言ってるかわからなかったので自分で動作確認してようやくわかりました。

行列の 同じ行 もしくは 同じ列 の隣接した点を結ぶことで線の描画を行うというものらしいです。 行、列という情報を付加するために 2D にしたんですねぇ。

以下操作ログ。座標は適当です。

In [1]:
%autosave 0
%matplotlib inline
Autosave disabled
In [2]:
from itertools import chain
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import display
In [3]:
xs = np.array([
    [1, 2, 4, 6],
    [1, 2, 4, 6],
    [1, 2, 4, 6],
    [1, 2, 4, 4],
])
ys = np.array([
    [2, 2, 2, 2],
    [5, 5, 5, 5],
    [7, 7, 7, 7],
    [10, 10, 10, 10],
])
zs = np.array([
    [10, 4, 8, 3],
    [5, 15, 20, 30],
    [7, 14, 20, 42],
    [30, 20, 40, 60],
])
In [4]:
fig = plt.figure()
ax = Axes3D(fig)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.plot_surface(xs, ys, zs)
Out[4]:
<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x7f22a67a12b0>
In [5]:
# それぞれの点がどの座標にプロットされているかを表示してみる
pd.DataFrame(list(chain(*[
    [[(r, c), xc, yc, zc] for c, (xc, yc, zc) in enumerate(zip(xr, yr, zr), 1)]
    for r, (xr, yr, zr) in enumerate(zip(xs, ys, zs), 1)
])), columns=['(行番号,列番号)', 'x', 'y', 'z'])
Out[5]:
(行番号,列番号) x y z
0 (1, 1) 1 2 10
1 (1, 2) 2 2 4
2 (1, 3) 4 2 8
3 (1, 4) 6 2 3
4 (2, 1) 1 5 5
5 (2, 2) 2 5 15
6 (2, 3) 4 5 20
7 (2, 4) 6 5 30
8 (3, 1) 1 7 7
9 (3, 2) 2 7 14
10 (3, 3) 4 7 20
11 (3, 4) 6 7 42
12 (4, 1) 1 10 30
13 (4, 2) 2 10 20
14 (4, 3) 4 10 40
15 (4, 4) 4 10 60

警告

plot_surface() では データが小数でないと以下のようなエラーが発生することがあるようです。

AttributeError: 'Float' object has no attribute 'dtype'

とりあえず numpy の Array であれば dtype='float' を指定するだけでOKです。

Mesh grid

3D の描画に必要なデータ構造がわかりましたが、これを毎回手動で作成するのは手間です。

NumPy には 1次元のオブジェクトを組み合わせて多次元のメッシュ構造を作成する 機能があります。

mgrid

mgrid は slice を指定することでメッシュ構造を作成するオブジェクトです。

ちょっと独特な書き方ですね。

In [1]:
%autosave 0
%matplotlib inline
Autosave disabled
In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#from IPython.display import display
In [3]:
xs, ys = np.mgrid[1:10:3, 10:20:2]
In [4]:
pd.DataFrame(xs)
Out[4]:
0 1 2 3 4
0 1 1 1 1 1
1 4 4 4 4 4
2 7 7 7 7 7
In [5]:
pd.DataFrame(ys)
Out[5]:
0 1 2 3 4
0 10 12 14 16 18
1 10 12 14 16 18
2 10 12 14 16 18
In [6]:
pd.DataFrame(xs + ys)
Out[6]:
0 1 2 3 4
0 11 13 15 17 19
1 14 16 18 20 22
2 17 19 21 23 25
In [7]:
fig = plt.figure()
ax = Axes3D(fig)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.plot_surface(xs, ys, xs + ys)
Out[7]:
<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x7f08f9fdd6a0>

上記の例でいうと

  • xs には が複製された Array
  • ys には が複製された Array

が補完されることにより、それぞれに対して同じ大きさのメッシュ構造が作成されるわけです。

今回は 2次元でしたが、3次元以上の構造も作成することができます。

meshgrid

こっちは 1次元のArray (list) を可変長引数として受けとる関数です。

meshgrid 関数は 行, 列 の順番が mgrid とは 逆っぽいです。 (mgrid が逆なのかな)

In [1]:
%autosave 0
%matplotlib inline
Autosave disabled
In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#from IPython.display import display
In [3]:
x = np.arange(1, 10, 3)
y = np.arange(10, 20, 2)
In [4]:
xs, ys = np.meshgrid(x, y)
In [5]:
pd.DataFrame(xs)
Out[5]:
0 1 2
0 1 4 7
1 1 4 7
2 1 4 7
3 1 4 7
4 1 4 7
In [6]:
pd.DataFrame(ys)
Out[6]:
0 1 2
0 10 10 10
1 12 12 12
2 14 14 14
3 16 16 16
4 18 18 18
In [7]:
# 構造は違うけど描画目的なら同じようになるので気にしなくていい
fig = plt.figure()
ax = Axes3D(fig)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.plot_surface(xs, ys, xs + ys)
Out[7]:
<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x7f33e043f630>
In [8]:
# 引数の順番を逆にすると転置する
ys, xs = np.meshgrid(y, x)
In [9]:
pd.DataFrame(xs)
Out[9]:
0 1 2 3 4
0 1 1 1 1 1
1 4 4 4 4 4
2 7 7 7 7 7
In [10]:
pd.DataFrame(ys)
Out[10]:
0 1 2 3 4
0 10 12 14 16 18
1 10 12 14 16 18
2 10 12 14 16 18

自分はこっちのほうが直感的で好みです。

その他参考