
AI Swamp: Designing Smarter Python Agents Before They Mutate

This post walks through different object-oriented designs for building a robust agent framework that can adapt to different model hosting platforms.

"""
src/agent/base.py

Base Class builders for Tarot Reading Agent.
"""

import ollama
from abc import ABC, abstractmethod 

from ...utils._schemas import TaroPrompt
from .schemas import DecodeMeter
from .client import setup_client, LLM_MODEL_ID

STREAM_MODE = False 
## Method 1: provider registry built on __init_subclass__
class TaroProvider(ABC):
    providers = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        cls.history = []
        # Every provider class name is prefixed with "Taro", so slicing off the
        # first four characters yields the provider key (e.g. "ollama").
        provider = cls.__qualname__[4:].lower()

        if provider not in cls.providers:
            cls.providers[provider] = []
        else:
            # Register the subclass under its provider so we can later check
            # that every provider exposes the same set of agents and remind
            # ourselves to set up the missing subclasses.
            cls.providers[provider].append(cls)

            if len(cls.providers) > 1:
                # TODO: find the provider with the most registered subclasses,
                # then report the ones missing from the other providers.
                pass
            else:
                print("All providers are set and good to go.")
                
    @abstractmethod
    def preprocess(self, **kwargs) -> list[dict]:
        """ Subclasses must implement this to preprocess or validate input. """
        raise NotImplementedError

    @abstractmethod
    def postprocess(self, output: ollama.ChatResponse):
        """ Subclasses can implement this to postprocess the model output if necessary or validate that the outputs are correct. """
        raise NotImplementedError

    @abstractmethod
    def run(self, **kwargs) -> str:
        """ Main entrypoint to run the data pipeline and return model output. """
        raise NotImplementedError

class TaroClient(TaroProvider):
    agents: dict = {}

    def __init__(self, start_all: bool = False, **kwargs):
        self.history = []
        self.stream_mode = kwargs.get("stream_mode", STREAM_MODE)
        self.decode_options = kwargs.get("decode_options", DecodeMeter())
        self.model_id = kwargs.get("model_id", LLM_MODEL_ID)

        if start_all:
            # Attach an instance of every registered agent to this client.
            for agent, agent_cls in self.agents.items():
                setattr(self, agent, agent_cls())
                print(f"Initiated Agent: {agent}")

    @abstractmethod
    def preprocess(self, **kwargs) -> list[dict]:
        """ Subclasses must implement this to preprocess or validate input."""
        return super().preprocess(**kwargs)
    
    @abstractmethod
    def postprocess(self, output: ollama.ChatResponse):
        """ Subclasses can implement this to postprocess the model output if necessary or validate that the outputs are correct. """
        return super().postprocess(output)
    
    @abstractmethod 
    def run(self, **kwargs) -> str: # type: ignore
        """ Main entrypoint to run the data pipeline and return model output."""
        return super().run(**kwargs)

    def __init_subclass__(cls, prompt: TaroPrompt, **kwargs):
        super().__init_subclass__(**kwargs)

        cls.prompt = prompt
        cls.model_id = LLM_MODEL_ID
        cls.agents[cls.__qualname__] = cls

        print(f"Successfully registered new Jawa member, {cls.__qualname__} (id: {cls.prompt.label if isinstance(cls.prompt, TaroPrompt) else ''}) to our SandCrawler!")
        
class TaroOllama(TaroProvider):
    """ Uses Ollama by default as LLM hosting provider. """
    
    def __init__(self, **kwargs):
        # Cooperative call: when a concrete agent also inherits from TaroClient,
        # this picks up its defaults (stream_mode, decode_options, model_id).
        super().__init__(**kwargs)
        self.client = setup_client()

    @abstractmethod
    def preprocess(self, **kwargs) -> list[dict]:
        """ Subclasses must implement this to preprocess or validate input."""

        return list(self.prompt.build_chat_message(**kwargs))
    
    @abstractmethod
    def postprocess(self, output: ollama.ChatResponse):
        """ Subclasses can implement this to postprocess the model output if necessary or validate the outputs are correct. """
        # self._postprocess(output)
        return output.message.get("content", None)

    def run(self, **kwargs) -> str: # type: ignore
        """ Main entrypoint to run the data pipeline and return model output."""

        try:
            message = self.preprocess(**kwargs)
            output = self.client.chat(
                model=self.model_id,
                messages=message,
                stream=self.stream_mode,
                options=self.decode_kwargs
            )

            return self.postprocess(output)
        except Exception as e:
            raise RuntimeError(f"OllamaClient Error: {e}") from e

    @property
    def decode_kwargs(self):
        """ Returns the decoding options as a dict for the LLM call. """
        return self.decode_options.model_dump()
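
With Method 1 in place, one way a concrete agent could be wired up is to inherit from both the provider and the client, so that run() comes from TaroOllama while registration, the prompt, and the decoding defaults come from TaroClient. A minimal sketch of that wiring, assuming a TaroPrompt instance named reading_prompt already exists and that its build_chat_message accepts whatever keywords you pass to run() (both names are illustrative, not part of the framework):

class TaroReader(TaroOllama, TaroClient, prompt=reading_prompt):
    def preprocess(self, **kwargs) -> list[dict]:
        # Reuse the default Ollama chat-message construction.
        return super().preprocess(**kwargs)

    def postprocess(self, output: ollama.ChatResponse):
        # Reuse the default content extraction.
        return super().postprocess(output)

reader = TaroReader()
print(reader.run(question="What does the Tower card mean?"))

Note that the base order matters: putting TaroOllama first in the MRO is what makes the super() calls above resolve to the Ollama defaults.
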

# Method 2: a single flat ABC that wires provider defaults in __init_subclass__
class TaroClient(ABC):
    @abstractmethod
    def preprocess(self, **kwargs) -> list[dict]:
        """ Subclasses must implement this to preprocess or validate input."""

        return list(self.prompt.build_chat_message(**kwargs))
    
    @abstractmethod
    def postprocess(self, output: ollama.ChatResponse):
        """ Subclasses can implement this to postprocess the model output if necessary or validate the outputs are correct. """
        # self._postprocess(output)
        return output.message.get("content", None)

    def run(self, **kwargs) -> str: # type: ignore
        """ Main entrypoint to run the data pipeline and return model output."""

        try:
            message = self.preprocess(**kwargs)
            output = self.client.chat(
                model=self.model_id,
                messages=message,
                stream=self.stream_mode,
                options=self.decode_kwargs
            )

            return self.postprocess(output)
        except Exception as e:
            raise RuntimeError(f"OllamaClient Error: {e}") from e

    @property
    def decode_kwargs(self):
        """ Returns the decoding options as a dict for the LLM call. """
        return self.decode_options.model_dump()

    def __init_subclass__(cls, prompt: TaroPrompt, **kwargs):
        super().__init_subclass__(**kwargs)

        cls.prompt = prompt
        cls.model_id = LLM_MODEL_ID
        cls.client = setup_client()
        cls.decode_options = DecodeMeter()
        cls.stream_mode = STREAM_MODE
        
        print(f"Succesfully registered new Jawa member, {cls.__qualname__}(id: {cls.prompt.label if isinstance(cls.prompt, TaroPrompt) else ''})to our SandCrawler!")
