Tool実装
「設計は固まった。次はToolの実装だ。」
田中VPoEがエディタを開く。
「各Toolは単一責任で実装する。テスト可能で、エージェントなしでも単体で動作することが重要だ。LangChainのTool規約に従って実装しよう。」
Tool実装の原則
Tool実装の原則:
1. 単一責任 → 1つのToolは1つのことだけをする
2. 型安全 → 入出力の型を明確にする
3. エラーハンドリング → 失敗時も有用な情報を返す
4. テスト可能 → エージェントなしで単体テストできる
5. ドキュメント → descriptionにLLMが理解できる説明を書く
Tool 1: データ取得ツール
from langchain_core.tools import tool
import pandas as pd
import joblib
# グローバルにデータとモデルをロード
DATA = pd.read_csv('WA_Fn-UseC_-Telco-Customer-Churn.csv')
MODEL = joblib.load('churn_model.pkl')
SCALER = joblib.load('scaler.pkl')
@tool
def get_customer_data(customer_id: str) -> dict:
"""顧客IDからデータを取得する。顧客の属性情報を辞書形式で返す。"""
row = DATA[DATA['customerID'] == customer_id]
if row.empty:
return {"error": f"顧客ID '{customer_id}' が見つかりません"}
return row.iloc[0].to_dict()
Tool 2: 前処理ツール
import numpy as np
from sklearn.preprocessing import LabelEncoder
@tool
def preprocess_customer(raw_data: dict) -> dict:
"""生の顧客データを前処理してモデル入力形式に変換する。"""
try:
df = pd.DataFrame([raw_data])
# TotalChargesの型変換
df['TotalCharges'] = pd.to_numeric(df['TotalCharges'], errors='coerce').fillna(0)
# customerIDの削除
if 'customerID' in df.columns:
df = df.drop('customerID', axis=1)
# Churnの削除(存在する場合)
if 'Churn' in df.columns:
df = df.drop('Churn', axis=1)
# エンコーディング
binary_cols = ['gender', 'Partner', 'Dependents', 'PhoneService', 'PaperlessBilling']
le = LabelEncoder()
for col in binary_cols:
if col in df.columns:
df[col] = le.fit_transform(df[col])
service_cols = ['MultipleLines', 'OnlineSecurity', 'OnlineBackup',
'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies']
for col in service_cols:
if col in df.columns:
df[col] = df[col].replace({'No internet service': 'No', 'No phone service': 'No'})
multi_cols = service_cols + ['InternetService', 'Contract', 'PaymentMethod']
df = pd.get_dummies(df, columns=[c for c in multi_cols if c in df.columns], drop_first=True)
# 特徴量エンジニアリング
if 'tenure' in df.columns:
df['is_new_customer'] = (df['tenure'] <= 6).astype(int)
df['is_loyal'] = (df['tenure'] >= 48).astype(int)
# スケーリング
numeric_cols = ['tenure', 'MonthlyCharges', 'TotalCharges']
existing_numeric = [c for c in numeric_cols if c in df.columns]
df[existing_numeric] = SCALER.transform(df[existing_numeric])
return {
"features": df.values.tolist()[0],
"feature_names": df.columns.tolist(),
"success": True,
}
except Exception as e:
return {"error": str(e), "success": False}
Tool 3: 離反予測ツール
@tool
def predict_churn(features: list) -> dict:
"""前処理済みの特徴量から離反確率を予測する。"""
try:
features_array = np.array(features).reshape(1, -1)
probability = MODEL.predict_proba(features_array)[0][1]
# リスクレベルの判定
if probability >= 0.7:
risk_level = "HIGH"
elif probability >= 0.4:
risk_level = "MEDIUM"
else:
risk_level = "LOW"
return {
"churn_probability": round(float(probability), 4),
"risk_level": risk_level,
"success": True,
}
except Exception as e:
return {"error": str(e), "success": False}
Tool 4: SHAP分析ツール
import shap
@tool
def explain_churn_prediction(features: list, feature_names: list) -> dict:
"""SHAP値を使って離反予測の要因を説明する。"""
try:
features_array = np.array(features).reshape(1, -1)
# SHAP Explainerの作成
explainer = shap.TreeExplainer(MODEL)
shap_values = explainer.shap_values(features_array)
# クラス1(離反)のSHAP値
if isinstance(shap_values, list):
shap_vals = shap_values[1][0]
else:
shap_vals = shap_values[0]
# 上位要因の抽出
importance = sorted(
zip(feature_names, shap_vals, features_array[0]),
key=lambda x: abs(x[1]),
reverse=True
)[:5]
explanations = []
for feat_name, shap_val, feat_val in importance:
direction = "離反促進" if shap_val > 0 else "離反抑制"
explanations.append({
"feature": feat_name,
"shap_value": round(float(shap_val), 4),
"feature_value": round(float(feat_val), 4),
"direction": direction,
})
return {
"explanations": explanations,
"base_value": round(float(explainer.expected_value[1]
if isinstance(explainer.expected_value, list)
else explainer.expected_value), 4),
"success": True,
}
except Exception as e:
return {"error": str(e), "success": False}
Tool 5: リテンション施策提案ツール
@tool
def suggest_retention_actions(
shap_explanations: list,
customer_data: dict
) -> dict:
"""SHAP分析結果と顧客データからリテンション施策を提案する。"""
actions = []
for exp in shap_explanations:
if exp["direction"] != "離反促進":
continue
feature = exp["feature"]
# 特徴量に応じた施策マッピング
action_map = {
"Contract_Month-to-month": {
"action": "年間契約への移行を提案(初月20%割引)",
"priority": "HIGH",
"expected_impact": "離反率を約30%低減",
},
"tenure": {
"action": "オンボーディングプログラムの強化",
"priority": "HIGH",
"expected_impact": "初期離反を25%低減",
},
"InternetService_Fiber optic": {
"action": "回線品質チェックと速度保証の提供",
"priority": "MEDIUM",
"expected_impact": "Fiber optic離反を15%低減",
},
"MonthlyCharges": {
"action": "プラン見直しの提案(適正プランへのダウングレード)",
"priority": "MEDIUM",
"expected_impact": "料金不満による離反を20%低減",
},
"OnlineSecurity_No": {
"action": "オンラインセキュリティ無料トライアルの提供",
"priority": "MEDIUM",
"expected_impact": "サービス定着率を10%向上",
},
"PaymentMethod_Electronic check": {
"action": "自動引き落としへの変更インセンティブ",
"priority": "LOW",
"expected_impact": "支払い関連離反を10%低減",
},
}
for key, action in action_map.items():
if key in feature:
actions.append(action)
break
if not actions:
actions.append({
"action": "定期的なフォローアップコールの実施",
"priority": "MEDIUM",
"expected_impact": "全般的な離反率を5%低減",
})
return {"actions": actions, "success": True}
Tool 6: 可視化ツール
import matplotlib.pyplot as plt
@tool
def visualize_churn_analysis(
churn_probability: float,
risk_level: str,
explanations: list
) -> dict:
"""離反分析結果をグラフで可視化する。"""
try:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# リスクゲージ
colors = {'LOW': '#2ecc71', 'MEDIUM': '#f39c12', 'HIGH': '#e74c3c'}
axes[0].barh(['離反確率'], [churn_probability], color=colors[risk_level])
axes[0].set_xlim(0, 1)
axes[0].set_title(f'離反リスク: {risk_level}')
axes[0].axvline(x=0.4, color='orange', linestyle='--', alpha=0.5)
axes[0].axvline(x=0.7, color='red', linestyle='--', alpha=0.5)
# SHAP要因
features = [e['feature'][:20] for e in explanations]
values = [e['shap_value'] for e in explanations]
bar_colors = ['#e74c3c' if v > 0 else '#2ecc71' for v in values]
axes[1].barh(features, values, color=bar_colors)
axes[1].set_title('離反要因(SHAP値)')
axes[1].set_xlabel('SHAP値')
plt.tight_layout()
path = 'churn_analysis_result.png'
plt.savefig(path, dpi=150)
plt.close()
return {"image_path": path, "success": True}
except Exception as e:
return {"error": str(e), "success": False}
まとめ
| Tool | 役割 | 入力 | 出力 |
|---|---|---|---|
| get_customer_data | データ取得 | customer_id | 顧客属性dict |
| preprocess_customer | 前処理 | raw_data | 特徴量ベクトル |
| predict_churn | 離反予測 | features | 確率, リスクレベル |
| explain_churn_prediction | SHAP分析 | features | 要因リスト |
| suggest_retention_actions | 施策提案 | SHAP結果 | アクションリスト |
| visualize_churn_analysis | 可視化 | 分析結果 | 画像パス |
チェックリスト
- LangChainの@toolデコレータを使ってToolを定義できる
- 各Toolの入出力の型を明確に定義できる
- エラーハンドリングを実装できる
- SHAP分析をToolとして実装できる
- 施策提案のルールベースマッピングを作成できる
次のステップへ
各Toolが実装できた。次はLangGraphでこれらのToolを接続し、ワークフローとして動作させよう。
推定読了時間: 30分