Source code for avantis_trader_sdk.feed.feed_client

import json
import websockets
from ..types import PriceFeedResponse, PriceFeedUpdatesResponse, PairInfoFeed
from typing import List, Callable
import requests
from pydantic import ValidationError
from ..config import AVANTIS_SOCKET_API
import asyncio
from concurrent.futures import ThreadPoolExecutor


[docs] class FeedClient: """ Client for interacting with the Pyth price feed websocket. """ def __init__( self, ws_url="wss://hermes.pyth.network/ws", on_error=None, on_close=None, hermes_url="https://hermes.pyth.network/v2/updates/price/latest", pair_fetcher: Callable = None, ): """ Constructor for the FeedClient class. Args: ws_url: Optional - The websocket URL to connect to. on_error: Optional callback for handling websocket errors. on_close: Optional callback for handling websocket close events. """ if ( ws_url is not None and not ws_url.startswith("ws://") and not ws_url.startswith("wss://") ): raise ValueError("ws_url must start with ws:// or wss://") self.ws_url = ws_url self.hermes_url = hermes_url self.pair_feeds = {} self.feed_pairs = {} self.price_feed_callbacks = {} self._socket = None self._connected = False self._on_error = on_error self._on_close = on_close self.pair_fetcher = pair_fetcher or self.default_pair_fetcher self._load_pair_feeds()
[docs] async def listen_for_price_updates(self): """ Listens for price updates from the Pyth price feed websocket. When a price update is received, the registered callbacks will be called with the updated price feed data. Raises: Exception: If an error occurs while listening for price updates. """ try: async with websockets.connect(self.ws_url) as websocket: self._socket = websocket self._connected = True await websocket.send( json.dumps( { "type": "subscribe", "ids": list(self.price_feed_callbacks.keys()), } ) ) while True: try: message = await websocket.recv() data = json.loads(message) if data["type"] == "price_update": price_feed_id = data["price_feed"]["id"] if price_feed_id in self.price_feed_callbacks: pair_string = self.get_pair_from_feed_id(price_feed_id) data["price_feed"]["pair"] = pair_string for callback in self.price_feed_callbacks[ price_feed_id ]: callback(PriceFeedResponse(**data["price_feed"])) except websockets.exceptions.ConnectionClosed as e: if self._on_close: self._on_close(e) else: print(f"Connection closed with error: {e}") break except Exception as e: if self._on_error: self._on_error(e) else: raise e except Exception as e: if self._on_error: self._on_error(e) else: raise e
[docs] async def default_pair_fetcher(self) -> List[dict]: """ Default pair fetcher that retrieves data from the Avantis API. Returns: A list of validated trading pairs. Raises: ValueError if API response is invalid. """ if not AVANTIS_SOCKET_API: raise ValueError("AVANTIS_SOCKET_API is not set") try: response = requests.get(AVANTIS_SOCKET_API) response.raise_for_status() result = response.json() pairs = result["data"]["pairInfos"].values() return pairs except (requests.RequestException, ValidationError) as e: print(f"Error fetching pair feeds: {e}") return []
def _load_pair_feeds(self): """ Loads the pair feeds dynamically using the provided pair_fetcher function. """ try: try: asyncio.get_running_loop() except RuntimeError: asyncio.set_event_loop(asyncio.new_event_loop()) with ThreadPoolExecutor() as executor: future = executor.submit(lambda: asyncio.run(self.pair_fetcher())) pairs = future.result() if not pairs: raise ValueError("Fetched pair feed data is empty or invalid.") if isinstance(pairs, dict): pairs = list(pairs.values()) else: pairs = list(pairs) if hasattr(pairs[0], "model_dump_json"): pairs = [json.loads(pair.model_dump_json()) for pair in pairs] validated_pairs = [PairInfoFeed.model_validate(pair) for pair in pairs] self.pair_feeds = { f"{pair.from_}/{pair.to}": {"id": pair.feed.feed_id} for pair in validated_pairs } self.feed_pairs = { pair.feed.feed_id: f"{pair.from_}/{pair.to}" for pair in validated_pairs } except Exception as e: print(f"Failed to load pair feeds: {e}")
[docs] def get_pair_from_feed_id(self, feed_id): """ Retrieves the pair string from the feed id. Args: feed_id: The feed id to retrieve the pair string for. Returns: The pair string. """ if not feed_id.startswith("0x"): feed_id = "0x" + feed_id return self.feed_pairs.get(feed_id)
[docs] def register_price_feed_callback(self, identifier, callback): """ Registers a callback for price feed updates. Args: identifier: The identifier of the price feed to register the callback for. callback: The callback to register. Raises: ValueError: If the identifier is unknown. """ if identifier in self.pair_feeds: price_feed_id = self.pair_feeds[identifier]["id"] elif identifier in self.feed_pairs: price_feed_id = identifier elif identifier in self.price_feed_callbacks: price_feed_id = identifier else: raise ValueError(f"Unknown identifier: {identifier}") if price_feed_id.startswith("0x"): price_feed_id = price_feed_id[2:] if price_feed_id not in self.price_feed_callbacks: self.price_feed_callbacks[price_feed_id] = [] self.price_feed_callbacks[price_feed_id].append(callback)
[docs] def unregister_price_feed_callback(self, identifier, callback): """ Unregisters a callback for price feed updates. Args: identifier: The identifier of the price feed to unregister the callback for. callback: The callback to unregister. """ if identifier in self.pair_feeds: price_feed_id = self.pair_feeds[identifier]["id"] else: price_feed_id = identifier if price_feed_id in self.price_feed_callbacks: self.price_feed_callbacks[price_feed_id].remove(callback) if not self.price_feed_callbacks[price_feed_id]: del self.price_feed_callbacks[price_feed_id]
[docs] async def get_latest_price_updates(self, identifiers: List[str]): """ Retrieves the latest price updates for the specified feed ids. Args: feedIds: The list of feed ids to retrieve the latest price updates for. Returns: A PriceFeedUpdatesResponse object containing the latest price updates. """ url = self.hermes_url feedIds = [] for identifier in identifiers: if identifier in self.pair_feeds: price_feed_id = self.pair_feeds[identifier]["id"] elif identifier in self.feed_pairs: price_feed_id = identifier else: raise ValueError(f"Unknown identifier: {identifier}") if price_feed_id.startswith("0x"): price_feed_id = price_feed_id[2:] feedIds.append(price_feed_id) params = {"ids[]": feedIds} response = requests.get(url, params=params) if response.status_code == 200: data = response.json() for i in range(len(data["parsed"])): data["parsed"][i] = PriceFeedResponse(**data["parsed"][i]) return PriceFeedUpdatesResponse(**data) else: response.raise_for_status()