前回、OCI で AI チャットを作りました。

あれはあれで、良いんですが、もうちょっと頑張って 履歴保存が出来るようにしてみました。

履歴保存するには、ユーザーを管理しないと出来ないんで、googleさんに認証してもらうことにしました。

別に、streamlit が対応してりゃ、他でも出来ると思います。

但し、ユーザーを一意識別する方法は、調べないとダメですよ。

ということで、なんちゃって 履歴保存付き AI チャットです。

履歴保存には、OCIの NoSQL を使いました。

しかしこれ、NoSQLなのに、がっつり SQL使ってるんですけど、どういうこと。

まー、良いんですけど、Deleteとかって、何故か一括で削除出来ないとか、仕様がよくわからんです。

ということで、これが完成ソースだー

table_limits は、費用バランスを考慮して修正してね。


chatdb.py

import oci

import datetime
import pytz

class chatdb:

    tablename = "ChatHistory2"
    childtablename = "messages"
    nosqlcl = None
    compartmentid = None

    def __init__(self, config, compid ):
        self.nosqlcl : oci.nosql.nosql_client.NosqlClient = oci.nosql.nosql_client.NosqlClient(config)
        self.compartmentid = compid

    # 索引作成
    def _create_table_index(self ):
        print(f"NoSQL インデックス 'message' を確認・作成中...")
        try:
            index_name = "message"
            create_index_details = oci.nosql.models.CreateIndexDetails(
                name=index_name,
                compartment_id=self.compartmentid,
                keys=[
                    oci.nosql.models.IndexKey(column_name="user_id"),
                    oci.nosql.models.IndexKey(column_name="message_timestamp")
                ],
                is_if_not_exists = True
            )
            response: oci.response.Response = self.nosqlcl.create_index(table_name_or_id=self.tablename ,create_index_details=create_index_details)

            oci.wait_until(
                self.nosqlcl,
                self.nosqlcl.get_index(table_name_or_id=self.tablename, index_name=index_name, compartment_id=self.compartmentid),
                'lifecycle_state',
                'ACTIVE'
            )
            print(f"インデックス '{index_name}' が ACTIVE 状態になりました。")

        except oci.exceptions.ServiceError as e:
            if e.code == 'IndexAlreadyExists':
                print(f"インデックス '{index_name}' は既に存在します。")
            else:
                print(f"インデックス作成中にエラーが発生しました: {e}")
                print(f"詳細エラーメッセージ: {e.message}")
        except Exception as e:
            print(f"予期せぬエラー: {e}")

    # 親テーブル作成
    def _create_table(self ):
        print(f"NoSQL テーブル '{self.tablename}' を確認・作成中...")
        try:

            try:
                print(f"NoSQL テーブル '{self.tablename}' を確認中...")

                # テーブルが既に存在するか確認
                response: oci.response.Response = self.nosqlcl.get_table(table_name_or_id=self.tablename, compartment_id=self.compartmentid)
                tbl : oci.nosql.models.Table = response.data
                if tbl.lifecycle_state == oci.nosql.models.Table.LIFECYCLE_STATE_ACTIVE:
                    print(f"テーブル '{self.tablename}' は既に存在します。")
                    return
            except oci.exceptions.ServiceError as e:
                if e.message.startswith('Table not found') == False:
                    print(f"テーブル確認中にエラーが発生しました: {e.message}")
                    return

            print(f"NoSQL テーブル '{self.tablename}' を作成中...")

            # TableLimitsを定義 (プロビジョニング容量の例)
            table_limits = oci.nosql.models.TableLimits(
                max_read_units=8,
                max_write_units=8,
                max_storage_in_g_bs=1,
                capacity_mode = oci.nosql.models.TableLimits.CAPACITY_MODE_PROVISIONED
            )        

            ddl_statement = f"""
            CREATE TABLE {self.tablename} (
                user_id STRING,
                session_id STRING,
                message_timestamp TIMESTAMP(3),
                title STRING,
                PRIMARY KEY (SHARD(user_id), session_id)
            ) USING TTL 90 days
            """

            create_table_details = oci.nosql.models.CreateTableDetails(
                name=self.tablename,
                compartment_id=self.compartmentid,
                ddl_statement=ddl_statement,
                table_limits=table_limits
            )

            response: oci.response.Response = self.nosqlcl.create_table(create_table_details)
            print(f"テーブル作成リクエスト送信済み。ワークリクエストID: {response.request_id}")

            oci.wait_until(
                self.nosqlcl,
                self.nosqlcl.get_table(table_name_or_id=self.tablename, compartment_id=self.compartmentid),
                'lifecycle_state',
                'ACTIVE'
            )
            print(f"テーブル '{self.tablename}' が ACTIVE 状態になりました。")

        except oci.exceptions.ServiceError as e:
            if e.code == 'TableAlreadyExists':
                print(f"テーブル '{self.tablename}' は既に存在します。")
            else:
                print(f"テーブル作成中にエラーが発生しました: {e}")
                print(f"詳細エラーメッセージ: {e.message}")
        except Exception as e:
            print(f"予期せぬエラー: {e}")

    # 子テーブル作成
    def _create_child_table( self):
        fulltablename = self.tablename + "." + self.childtablename
        print(f"NoSQL 子テーブル '{fulltablename}' を確認・作成中...")
        try:

            try:
                print(f"NoSQL 子テーブル '{fulltablename}' を確認中...")

                # テーブルが既に存在するか確認
                response: oci.response.Response = self.nosqlcl.get_table(table_name_or_id=fulltablename, compartment_id=self.compartmentid)
                tbl : oci.nosql.models.Table = response.data
                if tbl.lifecycle_state == oci.nosql.models.Table.LIFECYCLE_STATE_ACTIVE:
                    print(f"子テーブル '{fulltablename}' は既に存在します。")
                    return
            except oci.exceptions.ServiceError as e:
                if e.message.startswith('Table not found') == False:
                    print(f"テーブル確認中にエラーが発生しました: {e.message}")
                    return

            print(f"NoSQL 子テーブル '{fulltablename}' を作成中...")

            ddl_statement = f"""
            CREATE TABLE {fulltablename} (
                message_timestamp TIMESTAMP(3),
                role STRING,
                message STRING,
                PRIMARY KEY (message_timestamp)
            ) USING TTL 90 days
            """

            create_table_details = oci.nosql.models.CreateTableDetails(
                name=fulltablename,
                compartment_id=self.compartmentid,
                ddl_statement=ddl_statement
            )

            response: oci.response.Response = self.nosqlcl.create_table(create_table_details)
            print(f"子テーブル作成リクエスト送信済み。ワークリクエストID: {response.request_id}")

            oci.wait_until(
                self.nosqlcl,
                self.nosqlcl.get_table(table_name_or_id=fulltablename, compartment_id=self.compartmentid),
                'lifecycle_state',
                'ACTIVE'
            )
            print(f"子テーブル '{fulltablename}' が ACTIVE 状態になりました。")

        except oci.exceptions.ServiceError as e:
            if e.code == 'TableAlreadyExists':
                print(f"子テーブル '{fulltablename}' は既に存在します。")
            else:
                print(f"子テーブル作成中にエラーが発生しました: {e}")
                print(f"詳細エラーメッセージ: {e.message}")
        except Exception as e:
            print(f"予期せぬエラー: {e}")


    # チャット履歴保存
    def save_chat_message(self, user_id: str, session_id: str, role: str, message: str, title: str):

        fulltablename = self.tablename + "." + self.childtablename
        timestamp = datetime.datetime.now(pytz.utc).strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'

        # 親テーブル記録
        try:

            row_data = {
                'user_id': user_id,
                'session_id': session_id,
                'message_timestamp': timestamp,
                'title': title
            }
            put_row_details = oci.nosql.models.UpdateRowDetails(
                compartment_id=self.compartmentid,
                value=row_data
            )
            self.nosqlcl.update_row(
                table_name_or_id=self.tablename,
                update_row_details=put_row_details
            )
        except Exception as e:
            print(f"NoSQL: メッセージ保存中にエラーが発生しました: {e}")

        # メッセージ記録
        try:

            row_data = {
                'user_id': user_id,
                'session_id': session_id,
                'message_timestamp': timestamp,
                'role': role,
                'message': message
            }
            put_row_details = oci.nosql.models.UpdateRowDetails(
                compartment_id=self.compartmentid,
                value=row_data
            )
            self.nosqlcl.update_row(
                table_name_or_id=fulltablename,
                update_row_details=put_row_details
            )
        except Exception as e:
            print(f"NoSQL: メッセージ保存中にエラーが発生しました: {e}")

    # チャット履歴ロード (特定のセッションIDを指定)
    def load_chat_history_for_session(self, user_id: str, session_id: str):

        fulltablename = self.tablename + "." + self.childtablename

        print(f"NoSQL: ユーザー {user_id}, セッション {session_id} の履歴を読み込み中...")
        history = []
        try:
            query_statement = f"""
            SELECT role, message, message_timestamp
            FROM {fulltablename}
            WHERE user_id = '{user_id}' AND session_id = '{session_id}'
            ORDER BY user_id ASC, session_id ASC, message_timestamp ASC
            """
            query_details = oci.nosql.models.QueryDetails(
                compartment_id=self.compartmentid,
                statement=query_statement
            )
            response: oci.response.Response = self.nosqlcl.query(query_details)
            result: oci.nosql.models.QueryResultCollection = response.data

            for item in result.items:
                st_role = "assistant" if item['role'] == "CHATBOT" else "user"
                history.append({"role": st_role, "message": item['message']})
            print(f"NoSQL: {len(history)} 件の履歴を読み込みました。")
        except Exception as e:
            print(f"NoSQL: 履歴読み込み中にエラーが発生しました: {e}")
        return history

    # ユーザーの指定セッションを削除する
    def delete_user_session(self, user_id: str, session_id: str):

        fulltablename = self.tablename + "." + self.childtablename

        print(f"NoSQL: ユーザー {user_id}, セッション {session_id} を削除中...")

        # メッセージ削除
        try:
            query_statement = f"""
            SELECT role, message, message_timestamp
            FROM {fulltablename}
            WHERE user_id = '{user_id}' AND session_id = '{session_id}'
            """
            query_details = oci.nosql.models.QueryDetails(
                compartment_id=self.compartmentid,
                statement=query_statement
            )
            response: oci.response.Response = self.nosqlcl.query(query_details)
            result: oci.nosql.models.QueryResultCollection = response.data

            for item in result.items:
                response: oci.response.Response = self.nosqlcl.delete_row(
                    compartment_id=self.compartmentid,
                    table_name_or_id=fulltablename,
                    key=[f"user_id:{user_id}",f"session_id:{session_id}", f"message_timestamp:{item['message_timestamp']}"]
                    )
                drr:oci.nosql.models.DeleteRowResult = response.data
            print(f"NoSQL: セッション {session_id} の削除が完了しました。")
        except Exception as e:
            print(f"NoSQL: セッション削除中にエラーが発生しました: {e}")

        # 親テーブル削除
        try:
            query_statement = f"""
            SELECT title
            FROM {self.tablename}
            WHERE user_id = '{user_id}' AND session_id = '{session_id}'
            """
            query_details = oci.nosql.models.QueryDetails(
                compartment_id=self.compartmentid,
                statement=query_statement
            )
            response: oci.response.Response = self.nosqlcl.query(query_details)
            result: oci.nosql.models.QueryResultCollection = response.data

            for item in result.items:
                response: oci.response.Response = self.nosqlcl.delete_row(
                    compartment_id=self.compartmentid,
                    table_name_or_id=self.tablename,
                    key=[f"user_id:{user_id}",f"session_id:{session_id}"]
                    )
                drr:oci.nosql.models.DeleteRowResult = response.data
            print(f"NoSQL: セッション {session_id} の削除が完了しました。")
        except Exception as e:
            print(f"NoSQL: セッション削除中にエラーが発生しました: {e}")

    # ユーザーの全セッションIDを取得する新しい関数
    def get_user_session_ids(self, user_id: str):

        print(f"NoSQL: ユーザー {user_id} のセッションIDを検索中...")
        session_ids = []
        try:
            # DISTINCTキーワードを使用して、重複しないsession_idを取得
            query_statement = f"""
            SELECT session_id, message_timestamp, title
            FROM {self.tablename}
            WHERE user_id = '{user_id}'
            ORDER BY user_id DESC, message_timestamp DESC
            """
            query_details = oci.nosql.models.QueryDetails(
                compartment_id=self.compartmentid,
                statement=query_statement
            )
            response: oci.response.Response = self.nosqlcl.query(query_details)
            result: oci.nosql.models.QueryResultCollection = response.data

            for item in result.items:
                s_id = item['session_id']
                message_timestamp = item['message_timestamp']
                title = item['title']
                session_ids.append( [s_id, message_timestamp, title] )

            print(f"NoSQL: {len(session_ids)} 件のセッションIDを検出しました。")
        except Exception as e:
            print(f"NoSQL: セッションID取得中にエラーが発生しました: {e}")
        return session_ids

    def createtable( self ) :
        self._create_table()
        self._create_child_table()
        self._create_table_index()

app.py

#

import streamlit as st
import oci
from oci.generative_ai import GenerativeAiClient
from oci.generative_ai_inference import GenerativeAiInferenceClient
from oci.generative_ai_inference.models import (
    ChatDetails,
    CohereChatRequest,
    OnDemandServingMode
)
import uuid
import datetime
import pytz
import hashlib

from chatdb import chatdb

# テーマの取得
theme = "dark" if st.config.get_option("theme.base") == "dark" else "light"

# モバイル表示の問題を修正
# テーマに応じたCSSを適用
st.markdown(f"""
<style>
@media (max-width: 800px) {{
    .stChatInput {{
        position: fixed;
        bottom: 0;
        left: 0;
        width: 100%;
        padding: 10px;
        z-index: 1000;
        transition: background-color 0.3s ease;
    }}
    .stChatInput textarea {{
        width: 100%;
        box-sizing: border-box;
    }}

    /* Lightモードのスタイル */
    .{theme}-mode .stChatInput {{
        background-color: #ffffff;
        border-top: 1px solid #ccc;
        box-shadow: 0 -2px 5px rgba(0, 0, 0, 0.1);
    }}

    /* Darkモードのスタイル */
    .dark-mode .stChatInput {{
        background-color: #1e1e1e;
        border-top: 1px solid #444;
        box-shadow: 0 -2px 5px rgba(0, 0, 0, 0.5);
    }}
    .dark-mode .stChatInput textarea {{
        color: #ffffff;
        background-color: #1e1e1e;
    }}
}}
</style>
""", unsafe_allow_html=True)

# OCI認証情報
config = oci.config.from_file("~/.oci/config", "DEFAULT")
COMPARTMENT_ID = "ocid1.compartment.oc1..aaaaaaaamgmw22hogecwqnunirb3urhoger4ihdgoilkdjkv2sabokaq5svc"

# Google認証
LOGINBTN = "Googleでログイン"
AUTHSECTION = "google"
# 認証ID識別子
def AUTHID(user) :
    return user.get("sub")
# 許可確認する
def isContain(oid) :
    return True

# チャットDB
db = chatdb(config,COMPARTMENT_ID)

# Generative AI クライアントの初期化
DEFAULT_MODEL = "cohere.command-latest"
client = GenerativeAiInferenceClient(config=config)
generative_ai_client = GenerativeAiClient(config)

# 日本タイムゾーン
jst_timezone = pytz.timezone('Asia/Tokyo')

#日時変換->JST
def parseDateTime( tm ) :
    return datetime.datetime.fromisoformat(tm.replace('Z', '+00:00')).astimezone(jst_timezone)

# セッションID生成
def generate_unique_session_id() -> str:
    return hashlib.md5(str(uuid.uuid4()).encode('utf-8')).hexdigest()[:12]

#モデル一覧
available_models = []
ret:oci.response.Response = generative_ai_client.list_models( compartment_id=COMPARTMENT_ID)
models:oci.generative_ai.models.ModelCollection = ret.data
model:oci.generative_ai.models.Model
for model in models.items:
    if( "CHAT" in model.capabilities ):
        available_models.append(model.display_name)

#タイトル
st.title("OCI AI Chat")

if not st.user.is_logged_in:
    st.title("ログインしてください")
    if st.button(LOGINBTN):
        st.login(AUTHSECTION)
        st.stop()
else:
    oid = AUTHID(st.user)
    
    if 'nosql_table_checked' not in st.session_state:
        db.createtable()
        st.session_state.nosql_table_checked = True

    if 'current_chat_session_id' not in st.session_state:
        st.session_state.current_chat_session_id = None

    if 'messages_loaded_for_session' not in st.session_state:
        st.session_state.messages_loaded_for_session = None

    # ログアウトボタン
    if st.button("ログアウト"):
        st.logout()

    # 利用可能権限チェック
    if( isContain(oid) == False ) :
        st.write("許可されていません")
    else :
        st.sidebar.header(f"Login: {st.user.name}")
        selected_model = st.sidebar.selectbox("使用するモデルを選択", available_models, index=available_models.index(DEFAULT_MODEL) if DEFAULT_MODEL in available_models else 0)

        # 過去履歴構築
        # ユーザーの全セッションIDを取得
        all_session_ids: list = db.get_user_session_ids(oid)
        # 新しいセッションを開始するためのオプションを追加
        NEWCHAT = "新しいチャットを開始"
        # 全セッション追加
        options = []
        options.append( ["-1", NEWCHAT, NEWCHAT] )
        for item in all_session_ids :
            session_id = item[0]
            jst_timestamp = parseDateTime(item[1])
            title = item[2]
            options.append([session_id,jst_timestamp,title])

        # サイドバーでセッションを選択
        selected_session_option = st.sidebar.selectbox(
            "過去チャットを選択", 
            options,
            index=0,
            format_func = lambda item: f"{item[2]}",
            key="session_select_box"
        )
        session_id = selected_session_option[0]
        message_timestamp = selected_session_option[1]
        title = selected_session_option[2]

        print(f"{selected_model},[{session_id}:{message_timestamp}:{title}],{st.session_state.current_chat_session_id}")

        # 新しいセッションIDが既存のものと異なる場合のみリセット
        if session_id == "-1":
            if st.session_state.messages_loaded_for_session is None and st.session_state.current_chat_session_id is not None:
                #新規で継続中
                st.session_state.messages = db.load_chat_history_for_session(oid, st.session_state.current_chat_session_id)
            else :
                st.session_state.current_chat_session_id = generate_unique_session_id()
                st.session_state.messages = []
                st.session_state.messages_loaded_for_session = None
        else:
            print(f"履歴ロード {session_id}")
            # 選択された既存のセッションIDをロード
            if st.session_state.current_chat_session_id != session_id:
                st.session_state.current_chat_session_id = session_id
                st.session_state.messages = db.load_chat_history_for_session(oid, st.session_state.current_chat_session_id)
                st.session_state.messages_loaded_for_session = session_id

        # セッションのリセットボタン
        if st.session_state.messages_loaded_for_session is None and st.session_state.current_chat_session_id is not None:
            if st.sidebar.button("リセット"):
                st.session_state.current_chat_session_id = None
                st.session_state.messages = []
                st.session_state.messages_loaded_for_session = None
                st.rerun()

        # 選択されたセッション履歴を削除
        if session_id != "-1":
            if st.sidebar.button("削除"):
                db.delete_user_session(oid, st.session_state.current_chat_session_id)
                st.session_state.current_chat_session_id = None
                st.session_state.messages = []
                st.session_state.messages_loaded_for_session = None
                st.rerun()

        # チャット履歴表示
        for message in st.session_state.messages:
            role = "assistant" if message["role"] == "assistant" else "user"
            with st.chat_message(role):
                if role == "assistant" :
                    st.markdown(message["message"])
                else:
                    st.text(message["message"])

        # チャット 入力待ち
        if prompt := st.chat_input("ここにメッセージを入力してください..."):
            # セッション チャット履歴追加
            st.session_state.messages.append({"role": "user", "message": prompt})
            # DB チャット履歴追加
            jstnow = datetime.datetime.now(jst_timezone).strftime('%Y-%m-%d %H')
            title = f"{jstnow} {prompt[:20]}"
            db.save_chat_message(oid, st.session_state.current_chat_session_id, "USER", prompt, title)

            with st.chat_message("user"):
                st.text(prompt)

            with st.chat_message("assistant"):
                with st.spinner("思考中..."):
                    chat_history = []
                    for message in st.session_state.messages:
                        if message["role"] == "user":
                            chat_history.append({"role": "USER", "message": message["message"]})
                        elif message["role"] == "assistant":
                            chat_history.append({"role": "CHATBOT", "message": message["message"]})

                    chat_request = CohereChatRequest(
                        message=prompt,
                        chat_history=chat_history[:-1] if chat_history else None,
                        max_tokens=3000,
                        temperature=0.7,
                        is_echo=True,
                        is_stream=False
                    )
                    serving_mode = OnDemandServingMode(model_id=selected_model)
                    chat_details = ChatDetails(
                        compartment_id=COMPARTMENT_ID,
                        chat_request=chat_request,
                        serving_mode=serving_mode
                    )
                    response = client.chat(chat_details)
                    bot_reply = response.data.chat_response.text

                    if bot_reply:
                        # セッション チャット履歴追加
                        st.session_state.messages.append({"role": "assistant", "message": bot_reply})
                        # DB チャット履歴追加
                        db.save_chat_message(oid, st.session_state.current_chat_session_id, "CHATBOT", bot_reply, title)
                        # 出力
                        st.markdown(bot_reply)


 


認証SSOは、GCP で、OAuth 2.0 クライアント ID を取得して 良い感じの 承認済みのリダイレクト URI を指定すればOkです。

./streamlit/secrets.toml

後は、ChatGPT とか、Gemeni に聞けばわかります。

こんな感じ Darkモードですね

 

Joomla templates by a4joomla