つれづれなる備忘録

日々の発見をあるがままに綴る

Pythonによるデータ処理14 ~ カルマンフィルタ

今回はカルマンフィルタを用いたデータ処理について紹介する。

1. カルマンフィルタの概要

 カルマンフィルタは制御や経済、マーケティングなどのデータ分析に使われるフィルタ。

qiita.com

観測(測定)で得られた値から状態の値(真値)を推定するというもので、状態と観測は以下の方程式を仮定している。 状態は、ひとつ前の状態xt-1と比例定数Gの積に変動wが加わったもの、観測値は状態xtに比例定数Fの積に観測ノイズvが加わったものとみなす。このとき変動w, 観測ノイズvは平均0で分散がW, Vとする正規分布に従うものとする。

 \begin{eqnarray*}
状態方程式: x_t &=& G_t x_{t-1} + w_t  \\
観測方程式: y_t &=& F_t x_t +v_t 
\end{eqnarray*}

ただし、

 \begin{eqnarray*}
w_t &\sim& N(0,W_t)  \\
v_t &\sim& N(0,V_t)
\end{eqnarray*}

Nは正規分布で、Gt, Ft, Wt, Vtは既知(仮定)であるとする。

フィルタリング分布  p(x_t | y_{1:t} ) = N(m_t,C_t)

に対して観測値ytを用いて時刻tの状態xtの確率分布を生成する。ここで確率分布は正規分布Nに従い、平均値mt, 分散Ctとする。カルマンフィルタは観測値(データ)から正規分布の平均値mを求める作業になる。(求め方は上記サイト参照)

2. カルマンフィルタの実装

まずカルマンフィルタを適用する観測データを用意する。ここでは正弦波に分散1の正規分布ノイズを乗せたものを観測値dataとする。 フィルタリングの効果を後で確認するため、正弦波の部分をTrend: ytとした。

import numpy as np
import matplotlib.pyplot as plt

samplerate = 200
x = np.arange(0, 100) / samplerate    # 波形生成のための時間軸の作成
yt = 1.5*np.sin(10*2*np.pi*x)
data = np.random.normal(loc=0, scale=1, size=len(x)) + yt
plt.plot(x,data)

"観測データ"
観測データ

以下カルマンフィルタとカルマン平滑化の関数を以下のように定義する。(上記サイト参照)カルマン平滑化は、時刻tのカルマンフィルタリング分布p(xt | y1:t)を時刻t+1の平滑化分布p(xt+1 | y1:T)で補正するというもので、カルマンフィルタよりもさらに滑らかになる効果がある。 カルマンフィルタの出力としてはm, カルマン平滑化の出力はsを利用する。(カルマン平滑化ではカルマンフィルタの出力Cが必要)

def kalman_filter(m, C, y, G=G, F=F, W=W, V=V):
    """
    Kalman Filter
    m: 時点t-1のフィルタリング分布の平均
    C: 時点t-1のフィルタリング分布の分散共分散行列
    y: 時点tの観測値
    """
    a = G @ m
    R = G @ C @ G.T + W
    f = F @ a
    Q = F @ R @ F.T + V
    # 逆行列と何かの積を取る場合は、invよりsolveを使った方がいいらしい
    K = (np.linalg.solve(Q.T, F @ R.T)).T
    # K = R @ F.T @ np.linalg.inv(Q)
    m = a + K @ (y - f)
    C = R - K @ F @ R
    return m, C

def kalman_smoothing(s, S, m, C, G=G, W=W):
    """
    Kalman smoothing
    """
    # 1時点先予測分布のパラメータ計算
    a = G @ m
    R = G @ C @ G.T + W
    # 平滑化利得の計算
    # solveを使った方が約30%速くなる
    A = np.linalg.solve(R.T, G @ C.T).T
    # A = C @ G.T @ np.linalg.inv(R)
    # 状態の更新
    s = m + A @ (s - a)
    S = C + A @ (S - R) @ A.T
    return s, S

G,F,W,Vを与える必要があるので、適当な値を仮定して割り当て、またmとCの初期値も割り当てておく。

G = np.array([[1]])
F = np.array([[1]])
W = np.array([[1]]) # 恣意的に与える必要がある
V = np.array([[10]]) # 上に同じ
T = 100

m0 = np.array([[0]])
C0 = np.array([[1e7]])

# 結果を格納するarray
m = np.zeros((T, 1))
C = np.zeros((T, 1, 1))
s = np.zeros((T, 1))
S = np.zeros((T, 1, 1))

カルマンフィルタの関数: kalman_filterおよびカルマン平滑化の関数: kalman_smoothingを逐次適用してm,C,s,Sを計算する。

# カルマンフィルタ
for t in range(T):
    if t == 0:
        m[t], C[t] = kalman_filter(m0, C0, data[t:t+1])
    else:
        m[t], C[t] = kalman_filter(m[t-1:t], C[t-1:t], data[t:t+1])

# カルマン平滑化
for t in range(T):
    t = T - t - 1
    if t == T - 1:
        s[t] = m[t]
        S[t] = C[t]
    else:
        s[t], S[t] = kalman_smoothing(s[t+1], S[t+1], m[t], C[t])

3. カルマンフィルタ適用結果

カルマンフィルタ、カルマン平滑化を元の観測値(Original)、ノイズなしの正弦波(Trend)と比較した。カルマンフィルタは観測値の変動に対して順応しており、カルマン平滑化はカルマンフィルタよりもさらに滑らかになっていることがわかる。

plt.figure(figsize=(8,6))
plt.plot(x,data,'--')
plt.plot(x,yt,'k--')
plt.plot(x,m,'r')
plt.plot(x,s,'m')
plt.legend(['Original','Trend','Kalman filter','Kalman smoothing'])

"カルマンフィルタリング"
カルマンフィルタリング

単純に正弦波(Trend)を抽出するだけであれば過去に紹介したローパスやバンドパスフィルタが有用だが

atatat.hatenablog.com

イレギュラーな変動含めて信号としてとらえる場合はカルマンフィルタが有用だと思う。