LESSON

演習:離反分析AIエージェントを実装しよう

「ここまでの知識を総動員して、実際に動くAIエージェントを作り上げてくれ。」

田中VPoEが真剣な眼差しで言う。

「営業チームが月曜の朝に『先週の高リスク顧客を教えて』と聞いたら、即座に回答できるエージェントだ。Toolの実装からLangGraphの構築まで、一気にやろう。」

ミッション概要

LangGraphを使った離反分析AIエージェントを設計から実装まで一貫して構築する。完成したエージェントは、自然言語の質問に対して離反予測・要因分析・施策提案を返せること。


Mission 1: Tool群の実装(40分)

以下の5つのToolを実装せよ。各Toolは単体テストで動作確認すること。

  1. get_customer_data: 顧客IDからデータを取得
  2. preprocess_customer: 生データを前処理
  3. predict_churn: 離反確率を予測
  4. explain_prediction: SHAP分析で要因を説明
  5. suggest_actions: リテンション施策を提案
解答例
import pandas as pd
import numpy as np
import joblib
import shap
from langchain_core.tools import tool

# モデルとデータの準備
DATA = pd.read_csv('WA_Fn-UseC_-Telco-Customer-Churn.csv')
MODEL = joblib.load('churn_model.pkl')
SCALER = joblib.load('scaler.pkl')

@tool
def get_customer_data(customer_id: str) -> dict:
    """顧客IDから属性データを取得する"""
    row = DATA[DATA['customerID'] == customer_id]
    if row.empty:
        return {"error": f"ID '{customer_id}' not found", "success": False}
    return {"data": row.iloc[0].to_dict(), "success": True}

@tool
def preprocess_customer(raw_data: dict) -> dict:
    """生の顧客データをモデル入力形式に前処理する"""
    try:
        df = pd.DataFrame([raw_data])
        df['TotalCharges'] = pd.to_numeric(df['TotalCharges'], errors='coerce').fillna(0)
        for col in ['customerID', 'Churn']:
            if col in df.columns:
                df = df.drop(col, axis=1)

        from sklearn.preprocessing import LabelEncoder
        le = LabelEncoder()
        for col in ['gender', 'Partner', 'Dependents', 'PhoneService', 'PaperlessBilling']:
            if col in df.columns:
                df[col] = le.fit_transform(df[col])

        svc_cols = ['MultipleLines', 'OnlineSecurity', 'OnlineBackup',
                    'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies']
        for col in svc_cols:
            if col in df.columns:
                df[col] = df[col].replace({'No internet service': 'No', 'No phone service': 'No'})

        multi = svc_cols + ['InternetService', 'Contract', 'PaymentMethod']
        df = pd.get_dummies(df, columns=[c for c in multi if c in df.columns], drop_first=True)

        numeric = ['tenure', 'MonthlyCharges', 'TotalCharges']
        existing = [c for c in numeric if c in df.columns]
        df[existing] = SCALER.transform(df[existing])

        return {"features": df.values.tolist()[0], "names": df.columns.tolist(), "success": True}
    except Exception as e:
        return {"error": str(e), "success": False}

@tool
def predict_churn(features: list) -> dict:
    """離反確率を予測しリスクレベルを判定する"""
    try:
        arr = np.array(features).reshape(1, -1)
        prob = float(MODEL.predict_proba(arr)[0][1])
        level = "HIGH" if prob >= 0.7 else "MEDIUM" if prob >= 0.4 else "LOW"
        return {"probability": round(prob, 4), "risk_level": level, "success": True}
    except Exception as e:
        return {"error": str(e), "success": False}

@tool
def explain_prediction(features: list, feature_names: list) -> dict:
    """SHAP値で予測の要因を説明する"""
    try:
        arr = np.array(features).reshape(1, -1)
        explainer = shap.TreeExplainer(MODEL)
        sv = explainer.shap_values(arr)
        vals = sv[1][0] if isinstance(sv, list) else sv[0]
        factors = sorted(zip(feature_names, vals), key=lambda x: abs(x[1]), reverse=True)[:5]
        return {"factors": [{"feature": f, "shap": round(float(s), 4),
                "direction": "離反促進" if s > 0 else "離反抑制"} for f, s in factors],
                "success": True}
    except Exception as e:
        return {"error": str(e), "success": False}

@tool
def suggest_actions(factors: list, customer_data: dict) -> dict:
    """離反要因に基づきリテンション施策を提案する"""
    actions = []
    mapping = {
        "Contract": "年間契約への移行提案(初月20%OFF)",
        "tenure": "オンボーディング強化プログラム",
        "MonthlyCharges": "プラン最適化相談の案内",
        "InternetService": "回線品質チェックの実施",
        "OnlineSecurity": "セキュリティ3ヶ月無料トライアル",
        "TechSupport": "テックサポート優先対応",
        "PaymentMethod": "自動引落し変更で月¥500割引",
    }
    for f in factors:
        if f["direction"] == "離反促進":
            for key, action in mapping.items():
                if key in f["feature"]:
                    actions.append({"action": action, "trigger": f["feature"]})
                    break
    return {"actions": actions or [{"action": "フォローアップコール実施", "trigger": "general"}], "success": True}

# 単体テスト
print("=== Tool単体テスト ===")
res1 = get_customer_data.invoke("7590-VHVEG")
print(f"get_customer_data: {res1['success']}")

if res1["success"]:
    res2 = preprocess_customer.invoke(res1["data"])
    print(f"preprocess: {res2['success']}, features={len(res2['features'])}")

    res3 = predict_churn.invoke(res2["features"])
    print(f"predict: prob={res3['probability']}, risk={res3['risk_level']}")

    res4 = explain_prediction.invoke({"features": res2["features"], "feature_names": res2["names"]})
    print(f"explain: {len(res4['factors'])} factors")

    res5 = suggest_actions.invoke({"factors": res4["factors"], "customer_data": res1["data"]})
    print(f"suggest: {len(res5['actions'])} actions")

Mission 2: LangGraphワークフローの構築(40分)

以下の要件でLangGraphワークフローを構築せよ。

  1. State定義(全ノード間で共有するデータ構造)
  2. 8つのノード(意図分類/データ取得/前処理/予測/SHAP/施策/応答/エラー)
  3. 2つの条件分岐(意図分類後、予測後のリスクレベル分岐)
  4. エラーハンドリング(各ノードのエラーをキャッチ)
  5. コンパイルして実行テスト
解答例
from typing import TypedDict, Optional, Annotated
from operator import add
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
import re

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

class State(TypedDict):
    query: str
    query_type: Optional[str]
    customer_id: Optional[str]
    raw_data: Optional[dict]
    features: Optional[list]
    feature_names: Optional[list]
    probability: Optional[float]
    risk_level: Optional[str]
    factors: Optional[list]
    actions: Optional[list]
    response: Optional[str]
    error: Optional[str]

def classify(state):
    q = state["query"]
    id_match = re.search(r'\d{4}-[A-Z]{5}', q)
    return {"query_type": "individual" if id_match else "general",
            "customer_id": id_match.group() if id_match else None}

def fetch(state):
    r = get_customer_data.invoke(state["customer_id"])
    return {"raw_data": r.get("data"), "error": r.get("error") if not r.get("success") else None}

def preprocess_node(state):
    r = preprocess_customer.invoke(state["raw_data"])
    return {"features": r.get("features"), "feature_names": r.get("names"),
            "error": r.get("error") if not r.get("success") else None}

def predict_node(state):
    r = predict_churn.invoke(state["features"])
    return {"probability": r.get("probability"), "risk_level": r.get("risk_level"),
            "error": r.get("error") if not r.get("success") else None}

def explain_node(state):
    r = explain_prediction.invoke({"features": state["features"], "feature_names": state["feature_names"]})
    return {"factors": r.get("factors"), "error": r.get("error") if not r.get("success") else None}

def recommend_node(state):
    r = suggest_actions.invoke({"factors": state["factors"], "customer_data": state["raw_data"]})
    return {"actions": r.get("actions")}

def respond_node(state):
    ctx = {k: state.get(k) for k in ["customer_id","probability","risk_level","factors","actions"]}
    resp = llm.invoke(f"以下の離反分析結果を日本語でレポートにまとめてください:\n{ctx}")
    return {"response": resp.content}

def error_node(state):
    return {"response": f"エラー: {state.get('error', '不明')}"}

# グラフ構築
g = StateGraph(State)
for name, fn in [("classify",classify),("fetch",fetch),("preprocess",preprocess_node),
                  ("predict",predict_node),("explain",explain_node),("recommend",recommend_node),
                  ("respond",respond_node),("error",error_node)]:
    g.add_node(name, fn)

g.set_entry_point("classify")
g.add_conditional_edges("classify",
    lambda s: "fetch" if s["query_type"]=="individual" and not s.get("error") else ("error" if s.get("error") else "respond"),
    {"fetch":"fetch","respond":"respond","error":"error"})
g.add_edge("fetch", "preprocess")
g.add_edge("preprocess", "predict")
g.add_conditional_edges("predict",
    lambda s: "error" if s.get("error") else ("explain" if s.get("risk_level") in ["HIGH","MEDIUM"] else "respond"),
    {"explain":"explain","respond":"respond","error":"error"})
g.add_edge("explain", "recommend")
g.add_edge("recommend", "respond")
g.add_edge("respond", END)
g.add_edge("error", END)

app = g.compile()

# テスト
result = app.invoke({"query": "顧客 7590-VHVEG の離反リスクは?"})
print(result["response"])

Mission 3: 動作検証とエッジケース対応(40分)

構築したエージェントを以下のシナリオでテストし、結果を記録せよ。

  1. 正常系: 存在する顧客IDで離反リスクを問い合わせ
  2. 異常系: 存在しない顧客IDで問い合わせ
  3. 曖昧系: 顧客IDなしで「離反率が高い顧客の特徴は?」と質問
  4. エッジケース: 離反リスクLOWの顧客(SHAP分析がスキップされるパス)
  5. 各テストの入出力を記録し、改善点を3つ以上挙げる
解答例
# テストケース
test_cases = [
    {
        "name": "正常系 - 高リスク顧客",
        "query": "顧客 7590-VHVEG の離反リスクを分析してください",
        "expected": "離反確率とSHAP分析が含まれた回答",
    },
    {
        "name": "異常系 - 存在しないID",
        "query": "顧客 9999-XXXXX の離反リスクは?",
        "expected": "エラーメッセージ",
    },
    {
        "name": "曖昧系 - IDなし",
        "query": "離反率が高い顧客の特徴を教えて",
        "expected": "一般的な回答",
    },
    {
        "name": "エッジケース - 低リスク顧客",
        "query": "顧客 6388-TABGU の離反リスクは?",
        "expected": "低リスクの回答(SHAP分析なし)",
    },
]

results = []
for tc in test_cases:
    try:
        result = app.invoke({"query": tc["query"]})
        results.append({
            "テスト": tc["name"],
            "入力": tc["query"],
            "結果": result.get("response", "")[:100],
            "ステータス": "PASS" if result.get("response") else "FAIL",
        })
    except Exception as e:
        results.append({
            "テスト": tc["name"],
            "入力": tc["query"],
            "結果": str(e)[:100],
            "ステータス": "FAIL",
        })

print(pd.DataFrame(results).to_string(index=False))

# 改善点
improvements = [
    "1. セグメント分析(複数顧客の一括分析)機能の追加",
    "2. 回答にSHAP可視化画像を添付する機能",
    "3. 施策の優先度スコアリング(インパクト×実行容易性)",
    "4. 過去の分析結果のキャッシュ機能",
    "5. 対話形式での追加質問対応(フォローアップ)",
]
for imp in improvements:
    print(imp)

達成度チェック

  • 5つのToolを実装し単体テストで動作確認した
  • LangGraphで8ノードのワークフローを構築した
  • 2つの条件分岐(意図分類、リスクレベル)を実装した
  • エラーハンドリングが正しく動作する
  • 4つ以上のテストケースで動作検証した
  • 改善点を3つ以上特定した

推定所要時間: 120分