import json
import websockets
from ..types import (
PriceFeedResponse,
PriceFeedUpdatesResponse,
PairInfoFeed,
FeedV3PriceResponse,
LazerPriceFeedResponse,
)
from typing import List, Callable, Optional
import requests
from pydantic import ValidationError
from ..config import AVANTIS_SOCKET_API, AVANTIS_FEED_V3_URL, PYTH_LAZER_SSE_URL
import asyncio
from concurrent.futures import ThreadPoolExecutor
import aiohttp
# NOTE: a stray "[docs]" Sphinx source-link marker (HTML extraction residue) was here; it is not code.
class FeedClient:
    """
    Client for Pyth price data used by Avantis.

    The client can:

    * stream price updates from the Pyth Hermes websocket and dispatch them
      to registered callbacks (``listen_for_price_updates``),
    * fetch the latest prices over HTTP from Hermes
      (``get_latest_price_updates``),
    * fetch price update data from the feed-v3 API
      (``get_price_update_data``),
    * fetch or stream prices from the Pyth Lazer endpoints
      (``get_latest_lazer_price`` / ``listen_for_lazer_price_updates``).

    Pair names (``"FROM/TO"``) are mapped to feed ids using the data returned
    by ``pair_fetcher``; the mapping is loaded once in the constructor.
    """

    # Timeout, in seconds, applied to every blocking HTTP request.
    _HTTP_TIMEOUT = 10
    # Timeout, in seconds, for establishing the Lazer SSE connection.
    _SSE_CONNECT_TIMEOUT = 10

    def __init__(
        self,
        ws_url: Optional[str] = "wss://hermes.pyth.network/ws",
        on_error: Optional[Callable] = None,
        on_close: Optional[Callable] = None,
        hermes_url: str = "https://hermes.pyth.network/v2/updates/price/latest",
        socket_api: str = AVANTIS_SOCKET_API,
        pair_fetcher: Optional[Callable] = None,
        feed_v3_url: str = AVANTIS_FEED_V3_URL,
        lazer_sse_url: str = PYTH_LAZER_SSE_URL,
    ):
        """
        Constructor for the FeedClient class.

        Note that the constructor calls ``load_pair_feeds``, which invokes
        ``pair_fetcher`` (by default an HTTP request to ``socket_api``).

        Args:
            ws_url: Optional - The websocket URL to connect to (Pyth Hermes).
            on_error: Optional callback for handling websocket/SSE errors.
            on_close: Optional callback for handling websocket/SSE close events.
            hermes_url: Optional - The Hermes HTTP API URL.
            socket_api: Optional - The Avantis socket API URL.
            pair_fetcher: Optional - Custom pair fetcher function. May be a
                coroutine function or a plain function returning the pairs.
            feed_v3_url: Optional - The feed-v3 API URL for price update data.
            lazer_sse_url: Optional - The Pyth Lazer SSE URL for real-time prices.

        Raises:
            ValueError: If ``ws_url`` is given but is not a ws:// or wss:// URL.
        """
        if ws_url is not None and not ws_url.startswith(("ws://", "wss://")):
            raise ValueError("ws_url must start with ws:// or wss://")

        self.ws_url = ws_url
        self.hermes_url = hermes_url
        self.feed_v3_url = feed_v3_url
        self.lazer_sse_url = lazer_sse_url
        self.socket_api = socket_api

        # "FROM/TO" -> {"id": feed_id} and the reverse feed_id -> "FROM/TO".
        self.pair_feeds = {}
        self.feed_pairs = {}
        # feed id WITHOUT the "0x" prefix -> list of callbacks.
        self.price_feed_callbacks = {}
        self.lazer_callbacks = {}

        self._socket = None
        self._connected = False
        self._lazer_connected = False
        self._on_error = on_error
        self._on_close = on_close

        self.pair_fetcher = pair_fetcher or self.default_pair_fetcher
        self.load_pair_feeds()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _strip_0x(feed_id):
        """Return ``feed_id`` without a leading ``0x`` (callback-key form)."""
        if isinstance(feed_id, str) and feed_id.startswith("0x"):
            return feed_id[2:]
        return feed_id

    @staticmethod
    def _lazer_query(lazer_feed_ids) -> str:
        """Build the ``price_feed_ids=...&price_feed_ids=...`` query string."""
        return "&".join(f"price_feed_ids={fid}" for fid in lazer_feed_ids)

    def _lookup_feed_id(self, identifier):
        """
        Resolve a pair name or feed id against the loaded pair feeds.

        Returns the feed id (possibly still ``0x``-prefixed), or ``None`` when
        the identifier is not a known pair name or feed id.
        """
        if identifier in self.pair_feeds:
            return self.pair_feeds[identifier]["id"]
        if identifier in self.feed_pairs:
            return identifier
        # Accept a feed id given without "0x", mirroring get_pair_from_feed_id.
        if (
            isinstance(identifier, str)
            and not identifier.startswith("0x")
            and "0x" + identifier in self.feed_pairs
        ):
            return identifier
        return None

    async def _get_json(self, url, params=None):
        """
        GET ``url`` and return the decoded JSON body.

        ``requests`` is blocking, so the call is run in the loop's default
        executor to keep the event loop (and any running listener) responsive.

        Raises:
            requests.RequestException: On connection errors, timeouts, or a
                4xx/5xx response status.
        """
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(url, params=params, timeout=self._HTTP_TIMEOUT),
        )
        response.raise_for_status()
        return response.json()

    def _dispatch_price_update(self, data):
        """Forward one decoded websocket message to the matching callbacks."""
        # Ignore anything that is not a price update (e.g. subscribe acks).
        if not isinstance(data, dict) or data.get("type") != "price_update":
            return
        price_feed = data["price_feed"]
        price_feed_id = price_feed["id"]
        callbacks = self.price_feed_callbacks.get(price_feed_id)
        if not callbacks:
            return
        price_feed["pair"] = self.get_pair_from_feed_id(price_feed_id)
        # Iterate over a snapshot so a callback may unregister itself safely.
        for callback in tuple(callbacks):
            callback(PriceFeedResponse(**price_feed))

    # ------------------------------------------------------------------
    # Websocket (Pyth Hermes)
    # ------------------------------------------------------------------

    async def listen_for_price_updates(self):
        """
        Listens for price updates from the Pyth price feed websocket.

        Subscribes to every feed id that has a registered callback at the time
        of the call. When a price update is received, the registered callbacks
        are called with the updated price feed data. The method returns when
        the connection is closed.

        Errors are passed to ``on_error`` when it is set (and listening
        continues for per-message errors); a closed connection is passed to
        ``on_close`` when it is set.

        Raises:
            Exception: If an error occurs while listening for price updates
                and no ``on_error`` callback was provided.
        """
        try:
            async with websockets.connect(self.ws_url) as websocket:
                self._socket = websocket
                self._connected = True
                subscription = {
                    "type": "subscribe",
                    "ids": list(self.price_feed_callbacks),
                }
                await websocket.send(json.dumps(subscription))

                while True:
                    try:
                        message = await websocket.recv()
                        self._dispatch_price_update(json.loads(message))
                    except websockets.exceptions.ConnectionClosed as e:
                        if self._on_close:
                            self._on_close(e)
                        else:
                            print(f"Connection closed with error: {e}")
                        break
                    except Exception as e:
                        if self._on_error:
                            self._on_error(e)
                        else:
                            raise
        except Exception as e:
            if self._on_error:
                self._on_error(e)
            else:
                raise
        finally:
            # The socket is gone once the context manager exits; do not leave
            # stale connection state behind.
            self._socket = None
            self._connected = False

    # ------------------------------------------------------------------
    # Pair feeds
    # ------------------------------------------------------------------

    async def default_pair_fetcher(self) -> List[dict]:
        """
        Default pair fetcher that retrieves data from the Avantis API.

        Returns:
            The list of pair info entries found under ``data.pairInfos`` in
            the response, or an empty list if the request failed.

        Raises:
            ValueError: If ``socket_api`` is not set.
        """
        if not self.socket_api:
            raise ValueError("socket_api is not set")
        try:
            result = await self._get_json(self.socket_api)
            return list(result["data"]["pairInfos"].values())
        except (requests.RequestException, ValidationError) as e:
            print(f"Error fetching pair feeds: {e}")
            return []

    def load_pair_feeds(self):
        """
        Loads the pair feeds dynamically using the provided pair_fetcher function.

        Does nothing if the pair feeds are already loaded. Failures are
        reported on stdout and leave the mappings unchanged (best effort).
        """
        if self.pair_feeds:
            return

        def run_fetcher():
            result = self.pair_fetcher()
            # Support both coroutine functions and plain callables.
            if asyncio.iscoroutine(result):
                return asyncio.run(result)
            return result

        try:
            # Run in a worker thread: asyncio.run() cannot be called from a
            # thread that already has a running event loop.
            with ThreadPoolExecutor(max_workers=1) as executor:
                pairs = executor.submit(run_fetcher).result()

            if not pairs:
                raise ValueError("Fetched pair feed data is empty or invalid.")

            if isinstance(pairs, dict):
                pairs = list(pairs.values())
            else:
                pairs = list(pairs)

            # Pydantic models are round-tripped through JSON before validation.
            if hasattr(pairs[0], "model_dump_json"):
                pairs = [json.loads(pair.model_dump_json()) for pair in pairs]

            validated_pairs = [PairInfoFeed.model_validate(pair) for pair in pairs]

            self.pair_feeds = {
                f"{pair.from_}/{pair.to}": {"id": pair.feed.feed_id}
                for pair in validated_pairs
            }
            self.feed_pairs = {
                pair.feed.feed_id: f"{pair.from_}/{pair.to}"
                for pair in validated_pairs
            }
        except Exception as e:
            print(f"Failed to load pair feeds: {e}")

    def get_pair_from_feed_id(self, feed_id):
        """
        Retrieves the pair string from the feed id.

        Args:
            feed_id: The feed id to retrieve the pair string for, with or
                without the ``0x`` prefix.

        Returns:
            The pair string, or None if the feed id is unknown.
        """
        if not feed_id.startswith("0x"):
            feed_id = "0x" + feed_id
        return self.feed_pairs.get(feed_id)

    # ------------------------------------------------------------------
    # Callback registration
    # ------------------------------------------------------------------

    def register_price_feed_callback(self, identifier, callback):
        """
        Registers a callback for price feed updates.

        Args:
            identifier: The identifier of the price feed to register the
                callback for: a pair name ("FROM/TO") or a feed id, with or
                without the ``0x`` prefix.
            callback: The callback to register.

        Raises:
            ValueError: If the identifier is unknown.
        """
        price_feed_id = self._lookup_feed_id(identifier)
        if price_feed_id is None:
            # Still allow ids that already have callbacks registered.
            if identifier not in self.price_feed_callbacks:
                raise ValueError(f"Unknown identifier: {identifier}")
            price_feed_id = identifier

        key = self._strip_0x(price_feed_id)
        self.price_feed_callbacks.setdefault(key, []).append(callback)

    def unregister_price_feed_callback(self, identifier, callback):
        """
        Unregisters a callback for price feed updates.

        The identifier is normalised exactly like in
        ``register_price_feed_callback`` so that a callback registered by pair
        name or ``0x``-prefixed feed id can be removed with the same value.
        Nothing happens if no callbacks are registered for the feed.

        Args:
            identifier: The identifier of the price feed to unregister the
                callback for.
            callback: The callback to unregister.

        Raises:
            ValueError: If the feed has callbacks but ``callback`` is not one
                of them.
        """
        price_feed_id = self._lookup_feed_id(identifier)
        if price_feed_id is None:
            price_feed_id = identifier

        # Callbacks are stored under the id without "0x".
        key = self._strip_0x(price_feed_id)
        callbacks = self.price_feed_callbacks.get(key)
        if callbacks is None:
            return
        callbacks.remove(callback)
        if not callbacks:
            del self.price_feed_callbacks[key]

    # ------------------------------------------------------------------
    # HTTP endpoints
    # ------------------------------------------------------------------

    async def get_latest_price_updates(self, identifiers: List[str]):
        """
        Retrieves the latest price updates for the specified identifiers.

        Args:
            identifiers: The pair names ("FROM/TO") or feed ids to retrieve
                the latest price updates for.

        Returns:
            A PriceFeedUpdatesResponse object containing the latest price updates.

        Raises:
            ValueError: If an identifier is unknown.
            requests.RequestException: If the API request fails.
        """
        if not self.pair_feeds:
            self.load_pair_feeds()

        feed_ids = []
        for identifier in identifiers:
            price_feed_id = self._lookup_feed_id(identifier)
            if price_feed_id is None:
                raise ValueError(f"Unknown identifier: {identifier}")
            feed_ids.append(self._strip_0x(price_feed_id))

        data = await self._get_json(self.hermes_url, params={"ids[]": feed_ids})
        data["parsed"] = [PriceFeedResponse(**item) for item in data["parsed"]]
        return PriceFeedUpdatesResponse(**data)

    async def get_price_update_data(self, pair_index: int) -> FeedV3PriceResponse:
        """
        Retrieves price update data from the feed-v3 API for a specific pair.

        This returns both core (Pyth Hermes) and pro (Pyth Lazer) price data,
        including the priceUpdateData bytes needed for contract calls.

        Args:
            pair_index: The pair index to get price update data for.

        Returns:
            A FeedV3PriceResponse containing core and pro price data.

        Raises:
            requests.HTTPError: If the API request fails.
        """
        url = f"{self.feed_v3_url}/v2/pairs/{pair_index}/price-update-data"
        data = await self._get_json(url)
        return FeedV3PriceResponse(**data)

    async def get_latest_lazer_price(
        self, lazer_feed_ids: List[int]
    ) -> LazerPriceFeedResponse:
        """
        Retrieves the latest prices from the Pyth Lazer API.

        The endpoint is derived from ``lazer_sse_url`` by replacing
        ``/stream`` with ``/latest_price``.

        Args:
            lazer_feed_ids: List of Lazer feed IDs to get prices for.

        Returns:
            A LazerPriceFeedResponse containing the latest prices.

        Raises:
            requests.HTTPError: If the API request fails.
        """
        base_url = self.lazer_sse_url.replace("/stream", "/latest_price")
        url = f"{base_url}?{self._lazer_query(lazer_feed_ids)}"
        data = await self._get_json(url)
        return LazerPriceFeedResponse(**data)

    # ------------------------------------------------------------------
    # SSE (Pyth Lazer)
    # ------------------------------------------------------------------

    async def listen_for_lazer_price_updates(
        self,
        lazer_feed_ids: List[int],
        callback: Callable[[LazerPriceFeedResponse], None],
    ):
        """
        Listens for real-time price updates from the Pyth Lazer SSE stream.

        This is the Pyth Pro alternative to the WebSocket-based
        listen_for_price_updates. Events that cannot be decoded or validated
        are reported to ``on_error`` (if set) and skipped.

        Args:
            lazer_feed_ids: List of Lazer feed IDs to subscribe to.
            callback: Callback function to handle price updates.

        Raises:
            Exception: If an error occurs while listening for price updates
                and the matching ``on_error`` / ``on_close`` callback is not set.
        """
        url = f"{self.lazer_sse_url}?{self._lazer_query(lazer_feed_ids)}"
        # aiohttp's default total timeout (5 minutes) would cut off a
        # long-lived stream, so only the connection phase is bounded.
        timeout = aiohttp.ClientTimeout(
            total=None, sock_connect=self._SSE_CONNECT_TIMEOUT
        )
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(url) as response:
                    # Surface HTTP errors instead of parsing an error body
                    # as if it were an event stream.
                    response.raise_for_status()
                    self._lazer_connected = True
                    async for raw_line in response.content:
                        line = raw_line.decode("utf-8").strip()
                        if not line.startswith("data:"):
                            continue
                        try:
                            data = json.loads(line[5:].strip())
                            callback(LazerPriceFeedResponse(**data))
                        except (json.JSONDecodeError, ValidationError) as e:
                            if self._on_error:
                                self._on_error(e)
        except aiohttp.ClientError as e:
            self._lazer_connected = False
            if self._on_error:
                self._on_error(e)
            else:
                raise
        except Exception as e:
            self._lazer_connected = False
            if self._on_close:
                self._on_close(e)
            else:
                raise
        finally:
            # Also covers the stream ending without an exception.
            self._lazer_connected = False