AI Swamp: Designing Smarter Python Agents Before They Mutate
This post walks through different OOP designs for building a robust framework that lets you adapt to different model-hosting platforms.
"""
src/agent/base.py
Base Class builders for Tarot Reading Agent.
"""
import ollama
from abc import ABC, abstractmethod
from ...utils._schemas import TaroPrompt
from .schemas import DecodeMeter
from .client import setup_client, LLM_MODEL_ID
# Module-wide default for ``stream_mode``; the agents forward that value as
# ``stream=`` to the client's ``chat`` call.
STREAM_MODE = False
## Method 1
class TaroProvider(ABC):
    """Abstract base class for the LLM hosting providers behind the agents.

    Every subclass is recorded in the shared ``providers`` registry at the
    moment it is defined. The registry key is the subclass name with the
    leading ``"Taro"`` removed and lower-cased (``TaroOllama`` -> ``"ollama"``).

    Subclasses must implement ``preprocess``, ``postprocess`` and ``run``.
    """

    # One registry shared by the whole hierarchy (intentionally class-level):
    # provider key -> list of repeat registrations seen under that key.
    providers: dict[str, list[str]] = {}

    def __init_subclass__(cls, **kwargs):
        """Register ``cls`` in ``TaroProvider.providers`` and reset its history.

        Args:
            **kwargs: Forwarded unchanged to the next ``__init_subclass__``.
        """
        super().__init_subclass__(**kwargs)
        # Each subclass gets its own list so history is not shared by accident.
        cls.history = []
        # ``cls.__name__`` is the subclass' own name. ``cls.__class__`` would be
        # the metaclass (ABCMeta), and ``__qualname__`` would also include any
        # enclosing scope, so neither yields a usable provider key.
        provider = cls.__name__.removeprefix("Taro").lower()
        # Always go through TaroProvider so the registry is reachable while the
        # first subclass is still being created.
        registry = TaroProvider.providers
        if provider not in registry:
            registry[provider] = []
        else:
            # Ensures all providers have the same number of agents by
            # outputting a reminder to set up the subclasses for the providers.
            registry[provider].append(provider)
        if len(registry) > 1:
            # TODO: get the provider key & item with the max number of
            # subclasses, then find the missing ones in the other providers.
            pass
        else:
            print("All providers are set and good to go.")

    @abstractmethod
    def preprocess(self, **kwargs) -> list[dict]:
        """Preprocess or validate the input; subclasses must implement this.

        Raises:
            NotImplementedError: Always; the base class has no implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def postprocess(self, output: "ollama.ChatResponse"):
        """Postprocess or validate the model output; subclasses must implement this.

        The annotation is quoted so it is not evaluated when the class is built.

        Raises:
            NotImplementedError: Always; the base class has no implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def run(self, **kwargs) -> str:  # type: ignore
        """Main entrypoint to run the data pipeline and return the model output.

        Raises:
            NotImplementedError: Always; the base class has no implementation.
        """
        raise NotImplementedError
class TaroClient(TaroProvider):
    """Agent registry plus per-instance runtime settings (Method 1).

    Subclasses are declared with a ``prompt`` class keyword, e.g.
    ``class Reader(TaroClient, prompt=some_prompt)``. Each one is stored in
    the shared ``agents`` mapping under its qualified class name.
    """

    # Shared on purpose: qualified class name -> registered agent class.
    agents: dict = {}

    def __init__(self, start_all: bool = False, **kwargs):
        """Set up the runtime settings and optionally attach all agents.

        Args:
            start_all: When true, every registered agent class is attached to
                this instance as an attribute named after its registry key.
            **kwargs: Optional overrides ``stream_mode``, ``decode_options``
                and ``model_id``; module defaults are used when absent.
        """
        # Use the parameter itself: no ``start_all`` attribute is ever assigned,
        # so reading it from ``self`` would raise AttributeError.
        if start_all:
            for agent, agent_cls in self.agents.items():
                # Note: this attaches the agent *class*, not an instance.
                setattr(self, agent, agent_cls)
                print(f"Initiated Agent: {agent}")
        self.history = []
        self.stream_mode = kwargs.get("stream_mode", STREAM_MODE)
        self.decode_options = kwargs.get("decode_options", DecodeMeter())
        self.model_id = kwargs.get("model_id", LLM_MODEL_ID)

    @abstractmethod
    def preprocess(self, **kwargs) -> list[dict]:
        """Preprocess or validate the input; subclasses must implement this."""
        return super().preprocess(**kwargs)

    @abstractmethod
    def postprocess(self, output: ollama.ChatResponse):
        """Postprocess or validate the model output; subclasses must implement this."""
        # Delegate to the parent's ``postprocess`` with the argument this
        # method actually received.
        return super().postprocess(output)

    @abstractmethod
    def run(self, **kwargs) -> str:  # type: ignore
        """Main entrypoint to run the data pipeline and return the model output."""
        return super().run(**kwargs)

    def __init_subclass__(cls, prompt: TaroPrompt, **kwargs):
        """Register a new agent subclass and bind its prompt.

        Args:
            prompt: Prompt assigned to the subclass as ``cls.prompt``.
            **kwargs: Forwarded to the parent ``__init_subclass__``.
        """
        super().__init_subclass__(**kwargs)
        cls.prompt = prompt
        cls.model_id = LLM_MODEL_ID
        # Key by the same qualified name that is shown in the message below.
        cls.agents[cls.__qualname__] = cls
        print(f"Succesfully registered new Jawa member, {cls.__qualname__}(id: {cls.prompt.label if isinstance(cls.prompt, TaroPrompt) else ''})to our SandCrawler!")
class TaroOllama(TaroProvider):
    """Uses Ollama by default as the LLM hosting provider.

    ``preprocess`` and ``postprocess`` stay abstract but carry a default
    implementation that overriding subclasses can reach through ``super()``.
    ``run`` reads ``self.prompt``, ``self.model_id``, ``self.stream_mode`` and
    ``self.decode_options``, none of which are set in this class.
    """

    def __init__(self, **kwargs):
        """Initialise the next class in the MRO, then create the client.

        Args:
            **kwargs: Forwarded unchanged to ``super().__init__``.
        """
        super().__init__(**kwargs)
        self.client = setup_client()

    @abstractmethod
    def preprocess(self, **kwargs) -> list[dict]:
        """Build the chat messages for a request.

        Args:
            **kwargs: Forwarded to ``self.prompt.build_chat_message``.

        Returns:
            The messages produced by the prompt, materialised as a list.
        """
        # Forward the keyword arguments this method received; ``inputs`` was
        # never defined and raised NameError.
        return list(self.prompt.build_chat_message(**kwargs))

    @abstractmethod
    def postprocess(self, output: ollama.ChatResponse):
        """Extract the text content from a chat response.

        Args:
            output: Response returned by ``self.client.chat``.

        Returns:
            ``output.message``'s ``"content"`` entry, or ``None`` when absent.
        """
        return output.message.get("content", None)

    def run(self, **kwargs) -> str:  # type: ignore
        """Main entrypoint: preprocess, call the chat client, postprocess.

        Args:
            **kwargs: Forwarded to ``preprocess``.

        Returns:
            Whatever ``postprocess`` returns for the chat response.

        Raises:
            Exception: Wraps any failure with an ``OllamaClient Error`` prefix;
                the original exception is kept as ``__cause__``.
        """
        try:
            message = self.preprocess(**kwargs)
            output = self.client.chat(
                model=self.model_id,
                messages=message,
                stream=self.stream_mode,
                options=self.decode_kwargs,
            )
            return self.postprocess(output)
        except Exception as e:
            # Chain explicitly so the root cause stays visible in the traceback.
            raise Exception(f"OllamaClient Error: {e}") from e

    @property
    def decode_kwargs(self):
        """Return the decoding options as a plain dict for ``options=``."""
        # A JSON *string* is not a valid ``options`` value for the chat call;
        # a mapping is expected.
        # NOTE(review): assumes ``decode_options`` is a Pydantic v2 model that
        # exposes ``model_dump()``; confirm against ``schemas.DecodeMeter``.
        return self.decode_options.model_dump()
# Method 2
class TaroClient(ABC):
    """Single-base-class variant (Method 2): provider and agent setup in one.

    Subclasses are declared with a ``prompt`` class keyword. At definition
    time each subclass receives its own ``prompt``, ``model_id``, ``client``,
    ``decode_options`` and ``stream_mode`` class attributes.
    """

    @abstractmethod
    def preprocess(self, **kwargs) -> list[dict]:
        """Build the chat messages for a request.

        Abstract, but the default body is reachable through ``super()``.

        Args:
            **kwargs: Forwarded to ``self.prompt.build_chat_message``.

        Returns:
            The messages produced by the prompt, materialised as a list.
        """
        # Forward the keyword arguments this method received; ``inputs`` was
        # never defined and raised NameError.
        return list(self.prompt.build_chat_message(**kwargs))

    @abstractmethod
    def postprocess(self, output: ollama.ChatResponse):
        """Extract the text content from a chat response.

        Abstract, but the default body is reachable through ``super()``.

        Args:
            output: Response returned by ``self.client.chat``.

        Returns:
            ``output.message``'s ``"content"`` entry, or ``None`` when absent.
        """
        return output.message.get("content", None)

    def run(self, **kwargs) -> str:  # type: ignore
        """Main entrypoint: preprocess, call the chat client, postprocess.

        Args:
            **kwargs: Forwarded to ``preprocess``.

        Returns:
            Whatever ``postprocess`` returns for the chat response.

        Raises:
            Exception: Wraps any failure with an ``OllamaClient Error`` prefix;
                the original exception is kept as ``__cause__``.
        """
        try:
            message = self.preprocess(**kwargs)
            output = self.client.chat(
                model=self.model_id,
                messages=message,
                stream=self.stream_mode,
                options=self.decode_kwargs,
            )
            return self.postprocess(output)
        except Exception as e:
            # Chain explicitly so the root cause stays visible in the traceback.
            raise Exception(f"OllamaClient Error: {e}") from e

    @property
    def decode_kwargs(self):
        """Return the decoding options as a plain dict for ``options=``."""
        # A JSON *string* is not a valid ``options`` value for the chat call;
        # a mapping is expected.
        # NOTE(review): assumes ``decode_options`` is a Pydantic v2 model that
        # exposes ``model_dump()``; confirm against ``schemas.DecodeMeter``.
        return self.decode_options.model_dump()

    def __init_subclass__(cls, prompt: TaroPrompt, **kwargs):
        """Configure a newly defined agent subclass.

        Args:
            prompt: Prompt assigned to the subclass as ``cls.prompt``.
            **kwargs: Forwarded to the parent ``__init_subclass__``.
        """
        super().__init_subclass__(**kwargs)
        cls.prompt = prompt
        cls.model_id = LLM_MODEL_ID
        # Runs once per subclass, at class-definition time.
        cls.client = setup_client()
        cls.decode_options = DecodeMeter()
        cls.stream_mode = STREAM_MODE
        print(f"Succesfully registered new Jawa member, {cls.__qualname__}(id: {cls.prompt.label if isinstance(cls.prompt, TaroPrompt) else ''})to our SandCrawler!")
Last updated
Was this helpful?

