Skip to content

Commit b2273ae

Browse files
committed
refactor: improve typing
1 parent 46db3cc commit b2273ae

14 files changed

Lines changed: 97 additions & 71 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Install dependencies and run the service:
1414
```sh
1515
$ poetry env use $(which python) # point Poetry to the pyenv python shim
1616
$ poetry install
17-
$ poetry run pyth-observer
17+
$ poetry run pyth-observe --config sample.config.yaml --publishers sample.publishers.yaml --coingecko-mapping sample.coingecko.yaml
1818
```
1919

2020
Use `poetry run pyth-observer --help` for documentation on arguments and environment variables.

pyth_observer/__init__.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import os
3-
from typing import Any, Dict, List, Tuple
3+
from typing import Any, Dict, List, Literal, Tuple
44

55
from base58 import b58decode
66
from loguru import logger
@@ -18,6 +18,7 @@
1818
from throttler import Throttler
1919

2020
import pyth_observer.health_server as health_server
21+
from pyth_observer.check import State
2122
from pyth_observer.check.price_feed import PriceFeedState
2223
from pyth_observer.check.publisher import PublisherState
2324
from pyth_observer.coingecko import Symbol, get_coingecko_prices
@@ -35,7 +36,9 @@
3536
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
3637

3738

38-
def get_solana_urls(network) -> Tuple[str, str]:
39+
def get_solana_urls(
40+
network: Literal["devnet", "testnet", "mainnet", "pythtest", "pythnet"]
41+
) -> Tuple[str, str]:
3942
"""
4043
Helper for getting the correct urls for the PythClient
4144
"""
@@ -55,7 +58,7 @@ def __init__(
5558
config: Dict[str, Any],
5659
publishers: Dict[str, Publisher],
5760
coingecko_mapping: Dict[str, Symbol],
58-
):
61+
) -> None:
5962
self.config = config
6063
self.dispatch = Dispatch(config, publishers)
6164
self.publishers = publishers
@@ -77,9 +80,9 @@ def __init__(
7780
config=config,
7881
)
7982

80-
async def run(self):
83+
async def run(self) -> None:
8184
# global states
82-
states = []
85+
states: List[State] = []
8386
while True:
8487
try:
8588
logger.info("Running checks")
@@ -91,7 +94,7 @@ async def run(self):
9194
health_server.observer_ready = True
9295

9396
processed_feeds = 0
94-
active_publishers_by_symbol = {}
97+
active_publishers_by_symbol: Dict[str, Dict[str, Any]] = {}
9598

9699
for product in products:
97100
# Skip tombstone accounts with blank metadata
@@ -104,7 +107,7 @@ async def run(self):
104107
# For each product, we build a list of price feed states (one
105108
# for each price account) and a list of publisher states (one
106109
# for each publisher).
107-
states = []
110+
states: List[State] = []
108111
price_accounts = await self.get_pyth_prices(product)
109112

110113
crosschain_price = crosschain_prices.get(
@@ -249,7 +252,9 @@ async def get_pyth_prices(
249252
).inc()
250253
raise
251254

252-
async def get_coingecko_prices(self):
255+
async def get_coingecko_prices(
256+
self,
257+
) -> Tuple[Dict[str, float], Dict[str, int]]:
253258
logger.debug("Fetching CoinGecko prices...")
254259

255260
try:

pyth_observer/check/price_feed.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import time
22
from dataclasses import dataclass
33
from datetime import datetime
4-
from typing import Dict, Optional, Protocol, runtime_checkable
4+
from typing import Any, Dict, Optional, Protocol, runtime_checkable
55
from zoneinfo import ZoneInfo
66

77
import arrow
@@ -33,7 +33,7 @@ class PriceFeedState:
3333

3434
@runtime_checkable
3535
class PriceFeedCheck(Protocol):
36-
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig):
36+
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig) -> None:
3737
...
3838

3939
def state(self) -> PriceFeedState:
@@ -42,12 +42,12 @@ def state(self) -> PriceFeedState:
4242
def run(self) -> bool:
4343
...
4444

45-
def error_message(self) -> dict:
45+
def error_message(self) -> Dict[str, Any]:
4646
...
4747

4848

4949
class PriceFeedOfflineCheck(PriceFeedCheck):
50-
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig):
50+
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig) -> None:
5151
self.__state = state
5252
self.__max_slot_distance: int = int(config["max_slot_distance"])
5353
self.__abandoned_slot_distance: int = int(config["abandoned_slot_distance"])
@@ -79,7 +79,7 @@ def run(self) -> bool:
7979
# Fail
8080
return False
8181

82-
def error_message(self) -> dict:
82+
def error_message(self) -> Dict[str, Any]:
8383
distance = self.__state.latest_block_slot - self.__state.latest_trading_slot
8484
return {
8585
"msg": f"{self.__state.symbol} is offline (either non-trading/stale). Last update {distance} slots ago.",
@@ -91,7 +91,7 @@ def error_message(self) -> dict:
9191

9292

9393
class PriceFeedCoinGeckoCheck(PriceFeedCheck):
94-
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig):
94+
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig) -> None:
9595
self.__state = state
9696
self.__max_deviation: int = int(config["max_deviation"]) # Percentage
9797
self.__max_staleness: int = int(config["max_staleness"]) # Seconds
@@ -124,7 +124,7 @@ def run(self) -> bool:
124124
# Fail
125125
return False
126126

127-
def error_message(self) -> dict:
127+
def error_message(self) -> Dict[str, Any]:
128128
return {
129129
"msg": f"{self.__state.symbol} is too far from Coingecko's price.",
130130
"type": "PriceFeedCoinGeckoCheck",
@@ -135,7 +135,7 @@ def error_message(self) -> dict:
135135

136136

137137
class PriceFeedConfidenceIntervalCheck(PriceFeedCheck):
138-
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig):
138+
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig) -> None:
139139
self.__state = state
140140
self.__min_confidence_interval: int = int(config["min_confidence_interval"])
141141

@@ -154,7 +154,7 @@ def run(self) -> bool:
154154
# Fail
155155
return False
156156

157-
def error_message(self) -> dict:
157+
def error_message(self) -> Dict[str, Any]:
158158
return {
159159
"msg": f"{self.__state.symbol} confidence interval is too low.",
160160
"type": "PriceFeedConfidenceIntervalCheck",
@@ -164,7 +164,7 @@ def error_message(self) -> dict:
164164

165165

166166
class PriceFeedCrossChainOnlineCheck(PriceFeedCheck):
167-
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig):
167+
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig) -> None:
168168
self.__state = state
169169
self.__max_staleness: int = int(config["max_staleness"])
170170

@@ -204,7 +204,7 @@ def run(self) -> bool:
204204
# Fail
205205
return False
206206

207-
def error_message(self) -> dict:
207+
def error_message(self) -> Dict[str, Any]:
208208
if self.__state.crosschain_price:
209209
publish_time = arrow.get(self.__state.crosschain_price["publish_time"])
210210
else:
@@ -219,7 +219,7 @@ def error_message(self) -> dict:
219219

220220

221221
class PriceFeedCrossChainDeviationCheck(PriceFeedCheck):
222-
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig):
222+
def __init__(self, state: PriceFeedState, config: PriceFeedCheckConfig) -> None:
223223
self.__state = state
224224
self.__max_deviation: int = int(config["max_deviation"])
225225
self.__max_staleness: int = int(config["max_staleness"])
@@ -262,7 +262,7 @@ def run(self) -> bool:
262262
# Fail
263263
return False
264264

265-
def error_message(self) -> dict:
265+
def error_message(self) -> Dict[str, Any]:
266266
# It can never happen because of the check logic but linter could not understand it.
267267
price = (
268268
self.__state.crosschain_price["price"]

pyth_observer/check/publisher.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import defaultdict, deque
33
from dataclasses import asdict, dataclass
44
from datetime import datetime
5-
from typing import Dict, Protocol, runtime_checkable
5+
from typing import Any, Dict, Protocol, runtime_checkable
66
from zoneinfo import ZoneInfo
77

88
from loguru import logger
@@ -54,7 +54,7 @@ class PublisherState:
5454

5555
@runtime_checkable
5656
class PublisherCheck(Protocol):
57-
def __init__(self, state: PublisherState, config: PublisherCheckConfig):
57+
def __init__(self, state: PublisherState, config: PublisherCheckConfig) -> None:
5858
...
5959

6060
def state(self) -> PublisherState:
@@ -63,12 +63,12 @@ def state(self) -> PublisherState:
6363
def run(self) -> bool:
6464
...
6565

66-
def error_message(self) -> dict:
66+
def error_message(self) -> Dict[str, Any]:
6767
...
6868

6969

7070
class PublisherWithinAggregateConfidenceCheck(PublisherCheck):
71-
def __init__(self, state: PublisherState, config: PublisherCheckConfig):
71+
def __init__(self, state: PublisherState, config: PublisherCheckConfig) -> None:
7272
self.__state = state
7373
self.__max_interval_distance: int = int(config["max_interval_distance"])
7474

@@ -103,7 +103,7 @@ def run(self) -> bool:
103103
# Fail
104104
return False
105105

106-
def error_message(self) -> dict:
106+
def error_message(self) -> Dict[str, Any]:
107107
diff = self.__state.price - self.__state.price_aggregate
108108
intervals_away = abs(diff / self.__state.confidence_interval_aggregate)
109109
return {
@@ -117,7 +117,7 @@ def error_message(self) -> dict:
117117

118118

119119
class PublisherConfidenceIntervalCheck(PublisherCheck):
120-
def __init__(self, state: PublisherState, config: PublisherCheckConfig):
120+
def __init__(self, state: PublisherState, config: PublisherCheckConfig) -> None:
121121
self.__state = state
122122
self.__min_confidence_interval: int = int(config["min_confidence_interval"])
123123

@@ -141,7 +141,7 @@ def run(self) -> bool:
141141
# Fail
142142
return False
143143

144-
def error_message(self) -> dict:
144+
def error_message(self) -> Dict[str, Any]:
145145
return {
146146
"msg": f"{self.__state.publisher_name} confidence interval is too tight.",
147147
"type": "PublisherConfidenceIntervalCheck",
@@ -153,7 +153,7 @@ def error_message(self) -> dict:
153153

154154

155155
class PublisherOfflineCheck(PublisherCheck):
156-
def __init__(self, state: PublisherState, config: PublisherCheckConfig):
156+
def __init__(self, state: PublisherState, config: PublisherCheckConfig) -> None:
157157
self.__state = state
158158
self.__max_slot_distance: int = int(config["max_slot_distance"])
159159
self.__abandoned_slot_distance: int = int(config["abandoned_slot_distance"])
@@ -182,7 +182,7 @@ def run(self) -> bool:
182182
# Fail
183183
return False
184184

185-
def error_message(self) -> dict:
185+
def error_message(self) -> Dict[str, Any]:
186186
distance = self.__state.latest_block_slot - self.__state.slot
187187
return {
188188
"msg": f"{self.__state.publisher_name} hasn't published recently for {distance} slots.",
@@ -195,7 +195,7 @@ def error_message(self) -> dict:
195195

196196

197197
class PublisherPriceCheck(PublisherCheck):
198-
def __init__(self, state: PublisherState, config: PublisherCheckConfig):
198+
def __init__(self, state: PublisherState, config: PublisherCheckConfig) -> None:
199199
self.__state = state
200200
self.__max_aggregate_distance: int = int(config["max_aggregate_distance"]) # %
201201
self.__max_slot_distance: int = int(config["max_slot_distance"]) # Slots
@@ -230,7 +230,7 @@ def run(self) -> bool:
230230
# Fail
231231
return False
232232

233-
def error_message(self) -> dict:
233+
def error_message(self) -> Dict[str, Any]:
234234
deviation = (self.ci_adjusted_price_diff() / self.__state.price_aggregate) * 100
235235
return {
236236
"msg": f"{self.__state.publisher_name} price is too far from aggregate price.",
@@ -250,7 +250,7 @@ def ci_adjusted_price_diff(self) -> float:
250250

251251

252252
class PublisherStalledCheck(PublisherCheck):
253-
def __init__(self, state: PublisherState, config: PublisherCheckConfig):
253+
def __init__(self, state: PublisherState, config: PublisherCheckConfig) -> None:
254254
self.__state = state
255255
self.__stall_time_limit: int = int(
256256
config["stall_time_limit"]
@@ -313,7 +313,7 @@ def run(self) -> bool:
313313

314314
return not result.is_stalled
315315

316-
def error_message(self) -> dict:
316+
def error_message(self) -> Dict[str, Any]:
317317
stall_duration = f"{self.__last_analysis.duration:.1f} seconds"
318318
return {
319319
"msg": f"{self.__state.publisher_name} has been publishing the same price of {self.__state.symbol} for {stall_duration}",

pyth_observer/check/stall_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
stall_time_limit: float,
6565
noise_threshold: float = 1e-4,
6666
min_noise_samples: int = 5,
67-
):
67+
) -> None:
6868
"""
6969
Initialize stall detector.
7070

pyth_observer/cli.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import os
33
import sys
4+
from typing import Any, Dict
45

56
import click
67
import yaml
@@ -37,11 +38,13 @@
3738
envvar="PROMETHEUS_PORT",
3839
default="9001",
3940
)
40-
def run(config, publishers, coingecko_mapping, prometheus_port):
41-
config_ = yaml.safe_load(open(config, "r"))
41+
def run(
42+
config: str, publishers: str, coingecko_mapping: str, prometheus_port: str
43+
) -> None:
44+
config_: Dict[str, Any] = yaml.safe_load(open(config, "r")) # type: ignore[assignment]
4245
# Load publishers YAML file and convert to dictionary of Publisher instances
43-
publishers_raw = yaml.safe_load(open(publishers, "r"))
44-
publishers_ = {
46+
publishers_raw: list[Dict[str, Any]] = yaml.safe_load(open(publishers, "r")) # type: ignore[assignment]
47+
publishers_: Dict[str, Publisher] = {
4548
publisher["key"]: Publisher(
4649
key=publisher["key"],
4750
name=publisher["name"],
@@ -53,7 +56,7 @@ def run(config, publishers, coingecko_mapping, prometheus_port):
5356
)
5457
for publisher in publishers_raw
5558
}
56-
coingecko_mapping_ = yaml.safe_load(open(coingecko_mapping, "r"))
59+
coingecko_mapping_: Dict[str, Any] = yaml.safe_load(open(coingecko_mapping, "r")) # type: ignore[assignment]
5760
observer = Observer(
5861
config_,
5962
publishers_,
@@ -62,7 +65,7 @@ def run(config, publishers, coingecko_mapping, prometheus_port):
6265

6366
start_http_server(int(prometheus_port))
6467

65-
async def main():
68+
async def main() -> None:
6669
asyncio.create_task(start_health_server())
6770
await observer.run()
6871

pyth_observer/coingecko.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, TypedDict
1+
from typing import Any, Dict, TypedDict
22

33
from loguru import logger
44
from pycoingecko import CoinGeckoAPI
@@ -15,7 +15,9 @@ class Symbol(TypedDict):
1515
# However prices are updated every 1-10 minutes: https://www.coingecko.com/en/faq
1616
# Hence we only have to query once every minute.
1717
@throttle(rate_limit=1, period=60)
18-
async def get_coingecko_prices(mapping: Dict[str, Symbol]):
18+
async def get_coingecko_prices(
19+
mapping: Dict[str, Symbol],
20+
) -> Dict[str, Dict[str, Any]]:
1921
inverted_mapping = {mapping[x]["api"]: x for x in mapping}
2022
ids = [mapping[x]["api"] for x in mapping]
2123

pyth_observer/crosschain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class CrosschainPrice(TypedDict):
1616

1717

1818
class CrosschainPriceObserver:
19-
def __init__(self, url):
19+
def __init__(self, url: str) -> None:
2020
self.url = url
2121
self.valid = self.is_endpoint_valid()
2222

0 commit comments

Comments
 (0)